Fix bug in MixedMemory when concatenating list of tensors
lettercode committed Aug 6, 2024
1 parent ad72bb6 commit 5a1c4d9
Showing 5 changed files with 76 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
@@ -44,7 +44,7 @@ jobs:
runs-on: ubuntu-latest

container:
image: tensorflow/tensorflow:2.16.1
image: tensorflow/tensorflow:2.17.0
env:
KERAS_BACKEND: tensorflow
volumes:
62 changes: 23 additions & 39 deletions ncps/keras/cfc_cell.py
@@ -99,17 +99,18 @@ def build(self, input_shape):
else:
cat_shape = int(self.state_size + input_dim)

self.ff1_kernel = self.add_weight(
shape=(cat_shape, self.state_size),
initializer="glorot_uniform",
name="ff1_weight",
)
self.ff1_bias = self.add_weight(
shape=(self.state_size,),
initializer="zeros",
name="ff1_bias",
)

if self.mode == "pure":
self.ff1_kernel = self.add_weight(
shape=(cat_shape, self.state_size),
initializer="glorot_uniform",
name="ff1_weight",
)
self.ff1_bias = self.add_weight(
shape=(self.state_size,),
initializer="zeros",
name="ff1_bias",
)
self.w_tau = self.add_weight(
shape=(1, self.state_size),
initializer=keras.initializers.Zeros(),
@@ -121,16 +122,6 @@ def build(self, input_shape):
name="A",
)
else:
self.ff1_kernel = self.add_weight(
shape=(cat_shape, self.state_size),
initializer="glorot_uniform",
name="ff1_weight",
)
self.ff1_bias = self.add_weight(
shape=(self.state_size,),
initializer="zeros",
name="ff1_bias",
)
self.ff2_kernel = self.add_weight(
shape=(cat_shape, self.state_size),
initializer="glorot_uniform",
@@ -142,15 +133,6 @@ def build(self, input_shape):
name="ff2_bias",
)

# = keras.layers.Dense(
# , self._activation, name=f"{self.name}/ff1"
# )
# self.ff2 = keras.layers.Dense(
# self.state_size, self._activation, name=f"{self.name}/ff2"
# )
# if self.sparsity_mask is not None:
# self.ff1.build((None,))
# self.ff2.build((None, self.sparsity_mask.shape[0]))
self.time_a = keras.layers.Dense(self.state_size, name="time_a")
self.time_b = keras.layers.Dense(self.state_size, name="time_b")
input_shape = (None, self.state_size + input_dim)
@@ -202,16 +184,18 @@ def call(self, inputs, states, **kwargs):
return new_hidden, [new_hidden]

def get_config(self):
config = super(CfCCell, self).get_config()
config["units"] = self.units
config["mode"] = self.mode
config["activation"] = self._activation
config["backbone_units"] = self._backbone_units
config["backbone_layers"] = self._backbone_layers
config["backbone_dropout"] = self._backbone_dropout
config["sparsity_mask"] = self.sparsity_mask
return config
config = {
"units": self.units,
"mode": self.mode,
"activation": self._activation,
"backbone_units": self._backbone_units,
"backbone_layers": self._backbone_layers,
"backbone_dropout": self._backbone_dropout,
"sparsity_mask": self.sparsity_mask,
}
base_config = super().get_config()
return {**base_config, **config}

@classmethod
def from_config(cls, config):
def from_config(cls, config, custom_objects=None):
return cls(**config)
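
The get_config rewrite above follows the standard Keras 3 serialization pattern: the subclass collects only its own constructor arguments and merges them into the parent's config, and from_config rebuilds the layer from that dictionary. A minimal sketch of the same round-trip on a hypothetical TinyCell (not the actual CfCCell implementation):

    import keras

    class TinyCell(keras.layers.Layer):
        """Hypothetical cell used only to illustrate the config round-trip."""

        def __init__(self, units, mode="default", **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.mode = mode

        def get_config(self):
            # Subclass-specific arguments only; name/dtype/etc. come from the base.
            config = {"units": self.units, "mode": self.mode}
            base_config = super().get_config()
            return {**base_config, **config}

        @classmethod
        def from_config(cls, config, custom_objects=None):
            return cls(**config)

    cell = TinyCell(28, mode="pure")
    restored = TinyCell.from_config(cell.get_config())
    assert restored.units == 28 and restored.mode == "pure"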
5 changes: 4 additions & 1 deletion ncps/keras/mm_rnn.py
@@ -68,7 +68,10 @@ def build(self, sequences_shape, initial_state_shape=None):

def call(self, sequences, initial_state=None, mask=None, training=False, **kwargs):
memory_state, ct_state = initial_state
flat_ct_state = keras.ops.concatenate([ct_state], axis=-1)
if isinstance(ct_state, list):
flat_ct_state = keras.ops.concatenate(ct_state, axis=-1)
else:
flat_ct_state = ct_state
z = (
keras.ops.matmul(sequences, self.input_kernel)
+ keras.ops.matmul(flat_ct_state, self.recurrent_kernel)
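
The concatenation change above is the core of the fix: when the wrapped cell keeps several hidden tensors (as a wired CfC cell does), ct_state arrives as a list of tensors, and the old keras.ops.concatenate([ct_state], axis=-1) wrapped that list inside another list instead of flattening it. A minimal sketch of the intended behaviour, with made-up shapes:

    import keras

    # Hypothetical continuous-time state: two hidden tensors of width 8 and 20
    # for a batch of 4 sequences.
    ct_state = [keras.ops.ones((4, 8)), keras.ops.ones((4, 20))]

    def flatten_ct_state(ct_state):
        # List of state tensors: merge along the feature axis -> (4, 28).
        if isinstance(ct_state, list):
            return keras.ops.concatenate(ct_state, axis=-1)
        # Single tensor: pass through unchanged.
        return ct_state

    flat = flatten_ct_state(ct_state)                  # shape (4, 28)
    same = flatten_ct_state(keras.ops.ones((4, 28)))   # returned as-is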
14 changes: 8 additions & 6 deletions ncps/keras/wired_cfc_cell.py
@@ -150,12 +150,14 @@ def call(self, inputs, states, **kwargs):
return output, new_hiddens

def get_config(self):
config = super(WiredCfCCell, self).get_config()
config["wiring"] = self.wiring
config["fully_recurrent"] = self.fully_recurrent
config["mode"] = self.mode
config["activation"] = self._activation
return config
config = {
"wiring": self.wiring,
"fully_recurrent": self.fully_recurrent,
"mode": self.mode,
"activation": self._activation,
}
base_config = super().get_config()
return {**base_config, **config}

@classmethod
def from_config(cls, config):
45 changes: 40 additions & 5 deletions ncps/tests/test_keras.py
@@ -377,6 +377,42 @@ def test_fit_bidirectional_auto_ncp_ltc_mixed_memory():
model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)


def test_fit_cfc_mixed_memory_fix_batch_size_no_sequences():
data_x, data_y = prepare_test_data()
data_x = np.resize(data_x, (2, 48, 2))
data_y = np.resize(data_y, (2, 1, 2))
print("data_y.shape: ", str(data_y.shape))
model = keras.models.Sequential(
[
keras.layers.InputLayer(input_shape=(48, 2), batch_size=1),
CfC(28,
mixed_memory=True,
backbone_units=64,
backbone_dropout=0.3,
backbone_layers=2,
return_sequences=False),
keras.layers.Dense(1),
]
)
model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
model.fit(x=data_x, y=data_y, batch_size=2, epochs=3)


def test_fit_bidirectional_cfc_with_sum():
data_x, data_y = prepare_test_data()
print("data_y.shape: ", str(data_y.shape))
model = keras.models.Sequential(
[
keras.layers.InputLayer(input_shape=(None, 2)),
keras.layers.Bidirectional(CfC(28, return_sequences=False, unroll=True, mixed_memory=True),
merge_mode='sum'),
keras.layers.Dense(1),
]
)
model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)


def test_wiring_graph_auto_ncp_ltc():
data_x, data_y = prepare_test_data()
print("data_y.shape: ", str(data_y.shape))
@@ -840,29 +876,28 @@ def prune_details(config):
assert all([np.array_equal(l, m) for (l, m) in zip(loaded_model.get_weights(), model.get_weights())])


def test_save_and_load_weights_only_bidirectional_cfc_ncp():
def test_save_and_load_weights_only_bidirectional_cfc():
data_x, data_y = prepare_test_data()
print("data_y.shape: ", str(data_y.shape))
wiring = ncps.wirings.AutoNCP(28, 10)
model = keras.models.Sequential(
[
keras.layers.InputLayer(input_shape=(None, 2)),
keras.layers.Bidirectional(CfC(wiring, return_sequences=True)),
keras.layers.Bidirectional(CfC(28, return_sequences=True)),
keras.layers.Dense(1),
]
)
model2 = keras.models.Sequential(
[
keras.layers.InputLayer(input_shape=(None, 2)),
keras.layers.Bidirectional(CfC(wiring, return_sequences=True)),
keras.layers.Bidirectional(CfC(28, return_sequences=True)),
keras.layers.Dense(1),
]
)
model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)
model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
keras_file = f"{inspect.currentframe().f_code.co_name}.keras"
model.save(keras_file)
model2.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
model2.load_weights(keras_file)

assert all([np.array_equal(l, m) for (l, m) in zip(model2.get_weights(), model.get_weights())])

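The bidirectional tests added above exercise merge_mode='sum', which requires the forward and backward passes to produce tensors of the same shape; with return_sequences=False each direction emits a (batch, units) tensor, so the element-wise sum is well defined. A small shape check under the same assumptions (28 units, 2 input features; assuming CfC is importable from ncps.keras as in the test module):

    import numpy as np
    import keras
    from ncps.keras import CfC  # assumed import path

    x = np.random.default_rng(0).normal(size=(1, 16, 2)).astype("float32")
    layer = keras.layers.Bidirectional(
        CfC(28, return_sequences=False, mixed_memory=True), merge_mode="sum"
    )
    # Forward and backward outputs are both (1, 28), so their sum is (1, 28).
    y = layer(x)
    assert tuple(y.shape) == (1, 28)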