Pass config to constructor when reviving custom functional model (#20321)

* Pass config to model constructor when loading functional model

When loading a model instantiated from a custom Model subclass, its config is not
passed to its constructor. This leads to some parameters not being restored (see
the sketch below).

* Remove unused kwargs
TrAyZeN authored Oct 12, 2024
1 parent 4b1e65e commit acceb5a
Showing 2 changed files with 38 additions and 7 deletions.
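
For illustration, a minimal sketch of the failure mode this commit fixes, essentially what the new test below exercises. The class name WidthAwareModel and its width argument are hypothetical; only the get_config/from_config round trip mirrors the actual change.

import keras
from keras import layers

class WidthAwareModel(keras.Model):  # hypothetical example subclass
    def __init__(self, *args, width=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.width = width

    def get_config(self):
        # Report the extra constructor argument alongside the functional config.
        return {**super().get_config(), "width": self.width}

inputs = keras.Input((3,))
outputs = layers.Dense(5)(inputs)
model = WidthAwareModel(inputs=inputs, outputs=outputs, width=4)

# Before this commit, the revived model silently fell back to width=1 because
# the leftover config keys never reached the subclass constructor.
revived = WidthAwareModel.from_config(model.get_config())
assert revived.width == 4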
27 changes: 20 additions & 7 deletions keras/src/models/functional.py
@@ -132,7 +132,7 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
         if not all(is_input_keras_tensor(t) for t in flat_inputs):
             inputs, outputs = clone_graph_nodes(inputs, outputs)
 
-        Function.__init__(self, inputs, outputs, name=name, **kwargs)
+        Function.__init__(self, inputs, outputs, name=name)
 
         if trainable is not None:
             self.trainable = trainable
@@ -494,16 +494,28 @@ def process_layer(layer_data):
             # (e.g. a model such as A(B(A(B(x)))))
             add_unprocessed_node(layer, node_data)
 
+    # Extract config used to instantiate Functional model from the config. The
+    # remaining config will be passed as keyword arguments to the Model
+    # constructor.
+    functional_config = {}
+    for key in ["layers", "input_layers", "output_layers"]:
+        functional_config[key] = config.pop(key)
+    for key in ["name", "trainable"]:
+        if key in config:
+            functional_config[key] = config.pop(key)
+        else:
+            functional_config[key] = None
+
     # First, we create all layers and enqueue nodes to be processed
-    for layer_data in config["layers"]:
+    for layer_data in functional_config["layers"]:
         process_layer(layer_data)
 
     # Then we process nodes in order of layer depth.
     # Nodes that cannot yet be processed (if the inbound node
     # does not yet exist) are re-enqueued, and the process
     # is repeated until all nodes are processed.
     while unprocessed_nodes:
-        for layer_data in config["layers"]:
+        for layer_data in functional_config["layers"]:
             layer = created_layers[layer_data["name"]]
 
             # Process all nodes in layer, if not yet processed
@@ -532,8 +544,8 @@ def process_layer(layer_data):
                     del unprocessed_nodes[layer]
 
     # Create list of input and output tensors and return new class
-    name = config.get("name")
-    trainable = config.get("trainable")
+    name = functional_config["name"]
+    trainable = functional_config["trainable"]
 
     def get_tensor(layer_name, node_index, tensor_index):
         assert layer_name in created_layers
@@ -558,8 +570,8 @@ def map_tensors(tensors):
             return tuple([map_tensors(v) for v in tensors])
         return [map_tensors(v) for v in tensors]
 
-    input_tensors = map_tensors(config["input_layers"])
-    output_tensors = map_tensors(config["output_layers"])
+    input_tensors = map_tensors(functional_config["input_layers"])
+    output_tensors = map_tensors(functional_config["output_layers"])
     if isinstance(input_tensors, list) and len(input_tensors) == 1:
         input_tensors = input_tensors[0]
     if isinstance(output_tensors, list) and len(output_tensors) == 1:
@@ -570,6 +582,7 @@ def map_tensors(tensors):
         outputs=output_tensors,
         name=name,
         trainable=trainable,
+        **config,
     )
 
 
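To see the config split above in isolation, here is a standalone sketch of the pattern the new block implements; the helper name split_functional_config is illustrative, not part of the Keras API, and it condenses the if/else with pop's default. The functional-specific keys are popped out, and whatever remains is forwarded to the subclass constructor via **config.

# Hypothetical helper mirroring the split done in functional_from_config.
def split_functional_config(config):
    functional_config = {}
    for key in ["layers", "input_layers", "output_layers"]:
        functional_config[key] = config.pop(key)
    for key in ["name", "trainable"]:
        functional_config[key] = config.pop(key, None)
    return functional_config, config

config = {
    "layers": [],
    "input_layers": [],
    "output_layers": [],
    "name": "my_model",
    "trainable": True,
    "param": 3,  # extra constructor argument from a custom subclass
}
functional_config, remaining = split_functional_config(config)
# Only the leftover keys (here just "param") reach the subclass
# constructor through **config in the cls(...) call.
assert remaining == {"param": 3}
assert functional_config["name"] == "my_model"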
18 changes: 18 additions & 0 deletions keras/src/models/model_test.py
@@ -169,6 +169,24 @@ def call(self, x):
         )
         self.assertIsInstance(new_model, Functional)
 
+    def test_reviving_functional_from_config_custom_model(self):
+        class CustomModel(Model):
+            def __init__(self, *args, param=1, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.param = param
+
+            def get_config(self):
+                base_config = super().get_config()
+                config = {"param": self.param}
+                return base_config | config
+
+        inputs = layers.Input((3,))
+        outputs = layers.Dense(5)(inputs)
+        model = CustomModel(inputs=inputs, outputs=outputs, param=3)
+
+        new_model = CustomModel.from_config(model.get_config())
+        self.assertEqual(new_model.param, 3)
+
     @parameterized.named_parameters(
         ("single_output_1", _get_model_single_output),
         ("single_output_2", _get_model_single_output),
