From 4886e08ccb7adeefabb59badf1c45bdb8ba47c74 Mon Sep 17 00:00:00 2001
From: Osma Suominen
Date: Wed, 16 Aug 2023 11:15:05 +0300
Subject: [PATCH] Switch to Keras v3 format for nn_ensemble (with legacy h5 support)

The nn_ensemble model file name changes from nn-model.h5 to
nn-model.keras (MODEL_FILE). When that file is missing, initialize()
now looks for the legacy nn-model.h5 file (MODEL_FILE_FALLBACK) and
raises NotInitializedException only if neither file exists.

The tests are updated for the new file name, and a new test trains a
model with MODEL_FILE patched to nn-model.h5, removes the .keras file
and checks that suggest() loads the model from the .h5 file.
---
Review notes (this section is ignored by git am):
- Dropped the leftover print() that duplicated the self.debug()
  "loading Keras model from ..." message in initialize(); the hunk
  header and diffstat counts are adjusted accordingly.
- The post-image blob id on the nn_ensemble.py index line is carried
  over from the original patch and was not recomputed.
- Line breaks and blank context lines were reconstructed from the
  +/- markers and the hunk line counts; verify with
  `git apply --check` before applying.

 annif/backend/nn_ensemble.py      | 15 ++++++---
 tests/test_backend_nn_ensemble.py | 53 +++++++++++++++++++++++++++----
 2 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/annif/backend/nn_ensemble.py b/annif/backend/nn_ensemble.py
index 169eb8234..f24f1cc1f 100644
--- a/annif/backend/nn_ensemble.py
+++ b/annif/backend/nn_ensemble.py
@@ -97,7 +97,8 @@ class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBacke
 
     name = "nn_ensemble"
 
-    MODEL_FILE = "nn-model.h5"
+    MODEL_FILE = "nn-model.keras"
+    MODEL_FILE_FALLBACK = "nn-model.h5"
     LMDB_FILE = "nn-train.mdb"
 
     DEFAULT_PARAMETERS = {
@@ -122,11 +123,17 @@ def initialize(self, parallel: bool = False) -> None:
             return
         model_filename = os.path.join(self.datadir, self.MODEL_FILE)
         if not os.path.exists(model_filename):
-            raise NotInitializedException(
-                "model file {} not found".format(model_filename),
-                backend_id=self.backend_id,
+            model_filename_fallback = os.path.join(
+                self.datadir, self.MODEL_FILE_FALLBACK
             )
+            if os.path.exists(model_filename_fallback):
+                model_filename = model_filename_fallback
+            else:
+                raise NotInitializedException(
+                    "model file {} not found".format(model_filename),
+                    backend_id=self.backend_id,
+                )
         self.debug("loading Keras model from {}".format(model_filename))
         self._model = load_model(
             model_filename, custom_objects={"MeanLayer": MeanLayer}
         )
diff --git a/tests/test_backend_nn_ensemble.py b/tests/test_backend_nn_ensemble.py
index 1941e8665..71d1b417e 100644
--- a/tests/test_backend_nn_ensemble.py
+++ b/tests/test_backend_nn_ensemble.py
@@ -105,11 +105,11 @@ def test_nn_ensemble_train_and_learn(registry, tmpdir):
     assert nn_ensemble._model.optimizer.learning_rate.value() == 0.001
 
     datadir = py.path.local(project.datadir)
-    assert datadir.join("nn-model.h5").exists()
-    assert datadir.join("nn-model.h5").size() > 0
+    assert datadir.join("nn-model.keras").exists()
+    assert datadir.join("nn-model.keras").size() > 0
 
     # test online learning
-    modelfile = datadir.join("nn-model.h5")
+    modelfile = datadir.join("nn-model.keras")
 
     old_size = modelfile.size()
     old_mtime = modelfile.mtime()
@@ -129,7 +129,7 @@ def test_nn_ensemble_train_cached(registry):
     datadir = py.path.local(project.datadir)
     assert datadir.join("nn-train.mdb").exists()
 
-    datadir.join("nn-model.h5").remove()
+    datadir.join("nn-model.keras").remove()
 
     nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
     nn_ensemble = nn_ensemble_type(
@@ -140,8 +140,8 @@ def test_nn_ensemble_train_cached(registry):
 
     nn_ensemble.train("cached")
 
-    assert datadir.join("nn-model.h5").exists()
-    assert datadir.join("nn-model.h5").size() > 0
+    assert datadir.join("nn-model.keras").exists()
+    assert datadir.join("nn-model.keras").size() > 0
 
 
 def test_nn_ensemble_train_and_learn_params(registry, tmpdir, capfd):
@@ -254,3 +254,44 @@ def test_nn_ensemble_suggest(app_project):
 
     assert nn_ensemble._model is not None
     assert len(results) > 0
+
+
+def test_nn_ensemble_h5_fallback(registry, app_project, monkeypatch):
+    nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
+    nn_ensemble = nn_ensemble_type(
+        backend_id="nn_ensemble",
+        config_params={"sources": "dummy-en"},
+        project=app_project,
+    )
+
+    datadir = py.path.local(app_project.datadir)
+
+    # train model from cached data and save it in legacy HDF5 format
+    with monkeypatch.context() as m:
+        m.setattr(nn_ensemble, "MODEL_FILE", "nn-model.h5")
+        nn_ensemble.train("cached")
+
+    # remove any existing .keras model
+    datadir.join("nn-model.keras").remove()
+
+    # check that the trained model was saved in a HDF5 file
+    assert datadir.join("nn-model.h5").exists()
+    assert datadir.join("nn-model.h5").size() > 0
+
+    # delete the model, so it has to be loaded again
+    nn_ensemble._model = None
+
+    # now try using suggest, which forces loading it from .h5 file
+    results = nn_ensemble.suggest(
+        [
+            """Arkeologiaa sanotaan joskus myös
+        muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen
+        tiede tai oikeammin joukko tieteitä, jotka tutkivat ihmisen
+        menneisyyttä. Tutkimusta tehdään analysoimalla muinaisjäännöksiä
+        eli niitä jälkiä, joita ihmisten toiminta on jättänyt maaperään
+        tai vesistöjen pohjaan."""
+        ]
+    )[0]
+
+    assert nn_ensemble._model is not None
+    assert len(results) > 0