Skip to content

Commit

Permalink
Merge pull request #730 from NatLibFi/update-dependencies-v1.0-keras-…
Browse files Browse the repository at this point in the history
…save-format

Switch to Keras v3 save format for nn_ensemble
  • Loading branch information
osma authored Aug 16, 2023
2 parents 1c30cd5 + 1c8bc48 commit 88a19bf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
11 changes: 6 additions & 5 deletions annif/backend/nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from typing import TYPE_CHECKING, Any

import joblib
import keras.backend as K
import lmdb
import numpy as np
import tensorflow.keras.backend as K
from keras.layers import Add, Dense, Dropout, Flatten, Input, Layer
from keras.models import Model
from keras.saving import load_model
from keras.utils import Sequence
from scipy.sparse import csc_matrix, csr_matrix
from tensorflow.keras.layers import Add, Dense, Dropout, Flatten, Input, Layer
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import Sequence

import annif.corpus
import annif.parallel
Expand Down Expand Up @@ -97,7 +98,7 @@ class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBacke

name = "nn_ensemble"

MODEL_FILE = "nn-model.h5"
MODEL_FILE = "nn-model.keras"
LMDB_FILE = "nn-train.mdb"

DEFAULT_PARAMETERS = {
Expand Down
12 changes: 6 additions & 6 deletions tests/test_backend_nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ def test_nn_ensemble_train_and_learn(registry, tmpdir):
assert nn_ensemble._model.optimizer.learning_rate.value() == 0.001

datadir = py.path.local(project.datadir)
assert datadir.join("nn-model.h5").exists()
assert datadir.join("nn-model.h5").size() > 0
assert datadir.join("nn-model.keras").exists()
assert datadir.join("nn-model.keras").size() > 0

# test online learning
modelfile = datadir.join("nn-model.h5")
modelfile = datadir.join("nn-model.keras")

old_size = modelfile.size()
old_mtime = modelfile.mtime()
Expand All @@ -129,7 +129,7 @@ def test_nn_ensemble_train_cached(registry):
datadir = py.path.local(project.datadir)
assert datadir.join("nn-train.mdb").exists()

datadir.join("nn-model.h5").remove()
datadir.join("nn-model.keras").remove()

nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
nn_ensemble = nn_ensemble_type(
Expand All @@ -140,8 +140,8 @@ def test_nn_ensemble_train_cached(registry):

nn_ensemble.train("cached")

assert datadir.join("nn-model.h5").exists()
assert datadir.join("nn-model.h5").size() > 0
assert datadir.join("nn-model.keras").exists()
assert datadir.join("nn-model.keras").size() > 0


def test_nn_ensemble_train_and_learn_params(registry, tmpdir, capfd):
Expand Down

0 comments on commit 88a19bf

Please sign in to comment.