From 5a1c4d95368bc5ccdb457242137a59e1bbe82f84 Mon Sep 17 00:00:00 2001
From: lettercode <59030475+lettercode@users.noreply.github.com>
Date: Tue, 23 Jul 2024 21:15:07 +0200
Subject: [PATCH] Fix bug in MixedMemory when concatenating list of tensors

---
 .github/workflows/python-test.yml |  2 +-
 ncps/keras/cfc_cell.py            | 62 ++++++++++++-------------------
 ncps/keras/mm_rnn.py              |  5 ++-
 ncps/keras/wired_cfc_cell.py      | 14 ++++---
 ncps/tests/test_keras.py          | 45 +++++++++++++++++++---
 5 files changed, 76 insertions(+), 52 deletions(-)

diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml
index 97b3305..f85524a 100644
--- a/.github/workflows/python-test.yml
+++ b/.github/workflows/python-test.yml
@@ -44,7 +44,7 @@ jobs:
     runs-on: ubuntu-latest

     container:
-      image: tensorflow/tensorflow:2.16.1
+      image: tensorflow/tensorflow:2.17.0
      env:
        KERAS_BACKEND: tensorflow
      volumes:
diff --git a/ncps/keras/cfc_cell.py b/ncps/keras/cfc_cell.py
index d12c9c4..2871d61 100644
--- a/ncps/keras/cfc_cell.py
+++ b/ncps/keras/cfc_cell.py
@@ -99,17 +99,18 @@ def build(self, input_shape):
         else:
             cat_shape = int(self.state_size + input_dim)

+        self.ff1_kernel = self.add_weight(
+            shape=(cat_shape, self.state_size),
+            initializer="glorot_uniform",
+            name="ff1_weight",
+        )
+        self.ff1_bias = self.add_weight(
+            shape=(self.state_size,),
+            initializer="zeros",
+            name="ff1_bias",
+        )
+
         if self.mode == "pure":
-            self.ff1_kernel = self.add_weight(
-                shape=(cat_shape, self.state_size),
-                initializer="glorot_uniform",
-                name="ff1_weight",
-            )
-            self.ff1_bias = self.add_weight(
-                shape=(self.state_size,),
-                initializer="zeros",
-                name="ff1_bias",
-            )
             self.w_tau = self.add_weight(
                 shape=(1, self.state_size),
                 initializer=keras.initializers.Zeros(),
@@ -121,16 +122,6 @@
                 name="A",
             )
         else:
-            self.ff1_kernel = self.add_weight(
-                shape=(cat_shape, self.state_size),
-                initializer="glorot_uniform",
-                name="ff1_weight",
-            )
-            self.ff1_bias = self.add_weight(
-                shape=(self.state_size,),
-                initializer="zeros",
-                name="ff1_bias",
-            )
             self.ff2_kernel = self.add_weight(
                 shape=(cat_shape, self.state_size),
                 initializer="glorot_uniform",
@@ -142,15 +133,6 @@
                 name="ff2_bias",
             )

-        # = keras.layers.Dense(
-        #     , self._activation, name=f"{self.name}/ff1"
-        # )
-        # self.ff2 = keras.layers.Dense(
-        #     self.state_size, self._activation, name=f"{self.name}/ff2"
-        # )
-        # if self.sparsity_mask is not None:
-        #     self.ff1.build((None,))
-        #     self.ff2.build((None, self.sparsity_mask.shape[0]))
         self.time_a = keras.layers.Dense(self.state_size, name="time_a")
         self.time_b = keras.layers.Dense(self.state_size, name="time_b")
         input_shape = (None, self.state_size + input_dim)
@@ -202,16 +184,18 @@ def call(self, inputs, states, **kwargs):
         return new_hidden, [new_hidden]

     def get_config(self):
-        config = super(CfCCell, self).get_config()
-        config["units"] = self.units
-        config["mode"] = self.mode
-        config["activation"] = self._activation
-        config["backbone_units"] = self._backbone_units
-        config["backbone_layers"] = self._backbone_layers
-        config["backbone_dropout"] = self._backbone_dropout
-        config["sparsity_mask"] = self.sparsity_mask
-        return config
+        config = {
+            "units": self.units,
+            "mode": self.mode,
+            "activation": self._activation,
+            "backbone_units": self._backbone_units,
+            "backbone_layers": self._backbone_layers,
+            "backbone_dropout": self._backbone_dropout,
+            "sparsity_mask": self.sparsity_mask,
+        }
+        base_config = super().get_config()
+        return {**base_config, **config}

     @classmethod
-    def from_config(cls, config):
+    def from_config(cls, config, custom_objects=None):
         return cls(**config)
diff --git a/ncps/keras/mm_rnn.py b/ncps/keras/mm_rnn.py
index ad567aa..38b511e 100644
--- a/ncps/keras/mm_rnn.py
+++ b/ncps/keras/mm_rnn.py
@@ -68,7 +68,10 @@ def build(self, sequences_shape, initial_state_shape=None):
     def call(self, sequences, initial_state=None, mask=None, training=False, **kwargs):
         memory_state, ct_state = initial_state

-        flat_ct_state = keras.ops.concatenate([ct_state], axis=-1)
+        if isinstance(ct_state, list):
+            flat_ct_state = keras.ops.concatenate(ct_state, axis=-1)
+        else:
+            flat_ct_state = ct_state
         z = (
             keras.ops.matmul(sequences, self.input_kernel)
             + keras.ops.matmul(flat_ct_state, self.recurrent_kernel)
diff --git a/ncps/keras/wired_cfc_cell.py b/ncps/keras/wired_cfc_cell.py
index 7fbea37..7eb7489 100644
--- a/ncps/keras/wired_cfc_cell.py
+++ b/ncps/keras/wired_cfc_cell.py
@@ -150,12 +150,14 @@ def call(self, inputs, states, **kwargs):
         return output, new_hiddens

     def get_config(self):
-        config = super(WiredCfCCell, self).get_config()
-        config["wiring"] = self.wiring
-        config["fully_recurrent"] = self.fully_recurrent
-        config["mode"] = self.mode
-        config["activation"] = self._activation
-        return config
+        config = {
+            "wiring": self.wiring,
+            "fully_recurrent": self.fully_recurrent,
+            "mode": self.mode,
+            "activation": self._activation,
+        }
+        base_config = super().get_config()
+        return {**base_config, **config}

     @classmethod
     def from_config(cls, config):
diff --git a/ncps/tests/test_keras.py b/ncps/tests/test_keras.py
index ace65a8..a892268 100644
--- a/ncps/tests/test_keras.py
+++ b/ncps/tests/test_keras.py
@@ -377,6 +377,42 @@ def test_fit_bidirectional_auto_ncp_ltc_mixed_memory():
     model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)


+def test_fit_cfc_mixed_memory_fix_batch_size_no_sequences():
+    data_x, data_y = prepare_test_data()
+    data_x = np.resize(data_x, (2, 48, 2))
+    data_y = np.resize(data_y, (2, 1, 2))
+    print("data_y.shape: ", str(data_y.shape))
+    model = keras.models.Sequential(
+        [
+            keras.layers.InputLayer(input_shape=(48, 2), batch_size=1),
+            CfC(28,
+                mixed_memory=True,
+                backbone_units=64,
+                backbone_dropout=0.3,
+                backbone_layers=2,
+                return_sequences=False),
+            keras.layers.Dense(1),
+        ]
+    )
+    model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
+    model.fit(x=data_x, y=data_y, batch_size=2, epochs=3)
+
+
+def test_fit_bidirectional_cfc_with_sum():
+    data_x, data_y = prepare_test_data()
+    print("data_y.shape: ", str(data_y.shape))
+    model = keras.models.Sequential(
+        [
+            keras.layers.InputLayer(input_shape=(None, 2)),
+            keras.layers.Bidirectional(CfC(28, return_sequences=False, unroll=True, mixed_memory=True),
+                                       merge_mode='sum'),
+            keras.layers.Dense(1),
+        ]
+    )
+    model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
+    model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)
+
+
 def test_wiring_graph_auto_ncp_ltc():
     data_x, data_y = prepare_test_data()
     print("data_y.shape: ", str(data_y.shape))
@@ -840,29 +876,28 @@ def prune_details(config):
     assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())])


-def test_save_and_load_weights_only_bidirectional_cfc_ncp():
+def test_save_and_load_weights_only_bidirectional_cfc():
     data_x, data_y = prepare_test_data()
     print("data_y.shape: ", str(data_y.shape))
-    wiring = ncps.wirings.AutoNCP(28, 10)
     model = keras.models.Sequential(
         [
             keras.layers.InputLayer(input_shape=(None, 2)),
-            keras.layers.Bidirectional(CfC(wiring, return_sequences=True)),
+            keras.layers.Bidirectional(CfC(28, return_sequences=True)),
             keras.layers.Dense(1),
         ]
     )
     model2 = keras.models.Sequential(
         [
             keras.layers.InputLayer(input_shape=(None, 2)),
-            keras.layers.Bidirectional(CfC(wiring, return_sequences=True)),
+            keras.layers.Bidirectional(CfC(28, return_sequences=True)),
             keras.layers.Dense(1),
         ]
     )
     model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
     model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)
-    model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
     keras_file = f"{inspect.currentframe().f_code.co_name}.keras"
     model.save(keras_file)
+    model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
     model2.load_weights(keras_file)
     assert all([np.array_equal(l, m) for (l, m) in zip(model2.get_weights(), model.get_weights())])
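
Note (not part of the patch): a minimal sketch of why the ncps/keras/mm_rnn.py change matters. When the wrapped RNN cell keeps a single state tensor, ct_state is one tensor and the old keras.ops.concatenate([ct_state], axis=-1) is effectively a no-op; when the wrapped cell keeps several state tensors (e.g. an LSTM-style [h, c] pair), ct_state arrives as a list, and wrapping that list in another list breaks the concatenation. The shapes and helper name below are illustrative assumptions, not code from the repository; it assumes Keras 3 with any backend.

# Illustrative sketch only; names and shapes are hypothetical.
import keras

batch, units = 4, 28

# Single-state cell: ct_state is one tensor.
single_state = keras.ops.ones((batch, units))

# Multi-state cell (e.g. LSTM-style [h, c]): ct_state is a list of tensors.
list_state = [keras.ops.ones((batch, units)), keras.ops.ones((batch, units))]


def flatten_ct_state(ct_state):
    # Mirrors the patched branch: concatenate only when the state is a list;
    # the pre-patch code passed [ct_state] to concatenate unconditionally.
    if isinstance(ct_state, list):
        return keras.ops.concatenate(ct_state, axis=-1)
    return ct_state


print(flatten_ct_state(single_state).shape)  # (4, 28)
print(flatten_ct_state(list_state).shape)    # (4, 56)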