Skip to content

Commit

Permalink
Remove redundant reference to ocp.tree.serialize_tree(...).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697276504
  • Loading branch information
niketkumar authored and t5-copybara committed Nov 17, 2024
1 parent c723ab0 commit fb77046
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 18 deletions.
25 changes: 8 additions & 17 deletions t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from absl import logging
from etils import epath
import jax
import msgpack
import orbax.checkpoint as ocp
from tensorflow.io import gfile
Expand Down Expand Up @@ -261,24 +260,17 @@ def get_restore_parameters(
directory: epath.Path,
structure: PyTree,
) -> Tuple[PyTree, PyTree]:
"""Construct parameters needed for restoration.
"""Construct ParamInfos tree needed for restoration.
ParamInfos are
constructed from the structure of the original checkpoint, and restore_args
are serialized to a tree structure compatible with param_infos and structure.
ParamInfos are constructed from the structure of the original checkpoint.
Args:
directory: Checkpoint directory.
structure: The structure of the original checkpoint.
Returns:
Tuple of param_infos, and restore_args.
PyTree of `ParamInfo`.
"""
flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True)
param_names = ocp.tree.get_param_names(structure)
flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True)
restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure)
flat_param_infos = {}
is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory)
ts_context = ocp.type_handlers.get_ts_context()

Expand All @@ -303,11 +295,10 @@ def _get_param_info(
ts_context=ts_context,
)

flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True)
param_names = ocp.tree.get_param_names(structure)
flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True)
flat_param_infos = {}
for key, meta in flat_structure.items():
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
restore_args = ocp.tree.serialize_tree(restore_args, keep_empty_nodes=True)

return (
ocp.tree.from_flat_dict(flat_param_infos, target=structure),
restore_args,
)
return ocp.tree.from_flat_dict(flat_param_infos, target=structure)
2 changes: 1 addition & 1 deletion t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,7 @@ def _modify_orbax_param_info(info, value):
return info

item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args)
param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = jax.tree.map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
)
Expand Down

0 comments on commit fb77046

Please sign in to comment.