Skip to content

Commit

Permalink
Fix serialization / deserialization. (#20406)
Browse files Browse the repository at this point in the history
- Serialization was not taking the registered name and package from the registry.
- Deserialization was selecting symbols by postfix as a fallback.
  • Loading branch information
hertschuh authored Oct 25, 2024
1 parent 56eaab3 commit 0c2bdff
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 11 deletions.
2 changes: 1 addition & 1 deletion keras/src/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def test_saved_module_paths_and_class_names(self):
)
self.assertEqual(
config_dict["compile_config"]["loss"]["config"],
"my_mean_squared_error",
"my_custom_package>my_mean_squared_error",
)

@pytest.mark.requires_trainable_backend
Expand Down
11 changes: 1 addition & 10 deletions keras/src/saving/serialization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def _get_class_or_fn_config(obj):
"""Return the object's config depending on its type."""
# Functions / lambdas:
if isinstance(obj, types.FunctionType):
return obj.__name__
return object_registration.get_registered_name(obj)
# All classes:
if hasattr(obj, "get_config"):
config = obj.get_config()
Expand Down Expand Up @@ -781,15 +781,6 @@ def _retrieve_class_or_fn(
if obj is not None:
return obj

# Retrieval of registered custom function in a package
filtered_dict = {
k: v
for k, v in custom_objects.items()
if k.endswith(full_config["config"])
}
if filtered_dict:
return next(iter(filtered_dict.values()))

# Otherwise, attempt to retrieve the class object given the `module`
# and `class_name`. Import the module, find the class.
try:
Expand Down

0 comments on commit 0c2bdff

Please sign in to comment.