Skip to content

Commit

Permalink
Remove redundant reference to ocp.tree.serialize_tree(...) by remov…
Browse files Browse the repository at this point in the history
…ing dead code.

PiperOrigin-RevId: 697276504
  • Loading branch information
niketkumar authored and t5-copybara committed Nov 23, 2024
1 parent c723ab0 commit 3acb982
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 18 deletions.
23 changes: 6 additions & 17 deletions t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@

import enum
import os
from typing import Any, BinaryIO, Optional, Tuple, Union
from typing import Any, BinaryIO, Optional, Union

from absl import logging
from etils import epath
import jax
import msgpack
import orbax.checkpoint as ocp
from tensorflow.io import gfile
Expand Down Expand Up @@ -257,27 +256,21 @@ def _is_supported_empty_value(value: Any) -> bool:
return ocp.type_handlers.is_supported_empty_value(value)


def get_restore_parameters(
directory: epath.Path,
structure: PyTree,
) -> Tuple[PyTree, PyTree]:
"""Construct parameters needed for restoration.
def get_restore_parameters(directory: epath.Path, structure: PyTree) -> PyTree:
"""Construct ParamInfos tree needed for restoration.
ParamInfos are
constructed from the structure of the original checkpoint, and restore_args
are serialized to a tree structure compatible with param_infos and structure.
ParamInfos are constructed from the structure of the original checkpoint.
Args:
directory: Checkpoint directory.
structure: The structure of the original checkpoint.
Returns:
Tuple of param_infos, and restore_args.
PyTree of `ParamInfo`.
"""
flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True)
param_names = ocp.tree.get_param_names(structure)
flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True)
restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure)
flat_param_infos = {}
is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory)
ts_context = ocp.type_handlers.get_ts_context()
Expand Down Expand Up @@ -305,9 +298,5 @@ def _get_param_info(

for key, meta in flat_structure.items():
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
restore_args = ocp.tree.serialize_tree(restore_args, keep_empty_nodes=True)

return (
ocp.tree.from_flat_dict(flat_param_infos, target=structure),
restore_args,
)
return ocp.tree.from_flat_dict(flat_param_infos, target=structure)
2 changes: 1 addition & 1 deletion t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,7 @@ def _modify_orbax_param_info(info, value):
return info

item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args)
param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = jax.tree.map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
)
Expand Down

0 comments on commit 3acb982

Please sign in to comment.