Add branch to detect T5 checkpoints in OrbaxCheckpointManager, which checks if the checkpoint is T5. If it is, the legacy checkpointer is used to load the checkpoint.

PiperOrigin-RevId: 550325349
k-w-w authored and t5-copybara committed Aug 8, 2023
1 parent a257cac commit 22036d7
Showing 12 changed files with 269 additions and 28 deletions.
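In short, the Orbax-based checkpoint manager now inspects the `.checkpoint` file before restoring and falls back to the legacy T5X `Checkpointer` when a T5X-format checkpoint is detected. A minimal sketch of that dispatch, using the `checkpoint_utils.detect_checkpoint_type` API added in this commit (the `restore_any_checkpoint` helper and its arguments are hypothetical, for illustration only):

from t5x import checkpoint_utils


def restore_any_checkpoint(orbax_manager, legacy_checkpointer, ckpt_path, path=None):
  """Hypothetical helper mirroring the branch added to OrbaxCheckpointManager.restore()."""
  ckpt_type = checkpoint_utils.detect_checkpoint_type(
      ckpt_path, expected=checkpoint_utils.CheckpointTypes.ORBAX
  )
  if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF:
    # TensorFlow checkpoints need a dedicated restore path.
    raise ValueError(f'Use `restore_from_tf_checkpoint` for {ckpt_path}')
  if ckpt_type is checkpoint_utils.CheckpointTypes.T5X:
    # Legacy T5X msgpack checkpoint: delegate to the old Checkpointer.
    return legacy_checkpointer.restore(path=path)
  return orbax_manager.restore(path=path)  # Orbax-format checkpoint.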
159 changes: 158 additions & 1 deletion t5x/checkpoint_utils.py
@@ -18,16 +18,21 @@
removal process.
"""

import enum
import os
from typing import Any, BinaryIO, Optional

from absl import logging

from etils import epath
import msgpack
from tensorflow.io import gfile

# PINNED file in the checkpoint directory indicates that the checkpoint should
# not be removed during the automatic pruning of old checkpoints.
_PINNED_CHECKPOINT_FILENAME = 'PINNED'

PyTree = Any


def pinned_checkpoint_filepath(ckpt_dir: str) -> str:
"""Full path of the pinned checkpoint file."""
Expand Down Expand Up @@ -89,3 +94,155 @@ def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None:
gfile.remove(file)
else:
logging.info('Keeping pinned checkpoint: %s', ckpt_dir)


def _read_msgpack_keys(file_like: BinaryIO) -> PyTree:
  """Returns a tree containing all keys but no values from a msgpack file."""
  unpacker = msgpack.Unpacker(file_like)
  num_keys = unpacker.read_map_header()
  ret = {}

  # Contains references to the parent tree for each key to visit in the
  # msgpack file traversal.
  visit_stack = [ret for _ in range(num_keys)]
  while visit_stack:
    parent_dict = visit_stack.pop()
    key = unpacker.unpack()
    if isinstance(key, bytes):
      key = str(key, 'utf-8')

    # Check if the value object is a map.
    try:
      n = unpacker.read_map_header()
      ref = parent_dict[key] = {}
      visit_stack.extend(ref for _ in range(n))
    except msgpack.UnpackValueError:
      # Not a map, so skip unpacking the value object and record the current
      # key.
      unpacker.skip()
      parent_dict[key] = None

  return ret


def _contains_ts_spec(tree: PyTree) -> bool:
  """Returns whether a PyTree contains a serialized ts.Spec object."""
  to_visit = [tree]
  while to_visit:
    cur = to_visit.pop()
    if cur.keys() >= {'driver', 'kvstore', 'metadata'}:
      return True
    to_visit.extend(v for v in cur.values() if isinstance(v, dict))
  return False


# Constant copied from orbax/checkpoint/pytree_checkpoint_handler.py
_METADATA_FILE = '_METADATA'


def _contains_orbax_metadata(ckpt_path: str) -> bool:
  metadata = os.path.join(os.path.dirname(ckpt_path), _METADATA_FILE)
  return gfile.exists(metadata)


class CheckpointTypes(enum.Enum):
  ORBAX = 'ORBAX'
  T5X = 'T5X'
  T5X_TF = 'T5X_TF'


def _warn_if_unexpected_type(
    checkpoint_path, checkpoint_type, expected, extra_warn_log
):
  """Warns the user if an unexpected checkpoint type was found."""
  if expected is None or checkpoint_type == expected:
    return

  logging.warning(
      'Expected the checkpoint at %s to be %s format, but'
      ' the actual detected format was %s.',
      checkpoint_path,
      expected,
      checkpoint_type,
  )
  logging.warning(extra_warn_log)


def detect_checkpoint_type(
    checkpoint_path: epath.PathLike, expected: Optional[CheckpointTypes] = None
) -> CheckpointTypes:
  """Returns the checkpoint type by reading the `.checkpoint` metadata file.

  Args:
    checkpoint_path: The path of the `.checkpoint` file.
    expected: The expected checkpoint type. If the checkpoint type is not as
      expected, this function will log a warning but will not raise an error.

  Returns:
    The checkpoint type.
  """
  if _contains_orbax_metadata(checkpoint_path):
    checkpoint_type = CheckpointTypes.ORBAX
    _warn_if_unexpected_type(
        checkpoint_path,
        checkpoint_type,
        expected,
        f'Found `{_METADATA_FILE}` in the checkpoint directory, which only '
        'appears in Orbax checkpoints.',
    )
    return checkpoint_type

  with gfile.GFile(checkpoint_path, 'rb') as fp:
    raw_contents = fp.read(21)
    if raw_contents == b'model_checkpoint_path':
      checkpoint_type = CheckpointTypes.T5X_TF
      _warn_if_unexpected_type(
          checkpoint_path,
          checkpoint_type,
          expected,
          'The checkpoint file was not a msgpack, and had the string '
          '"model_checkpoint_path", so it was assumed to be in the T5X '
          'TensorFlow format.',
      )
      return checkpoint_type

    # Assume that if the msgpack file has exactly 'version' and 'optimizer' as
    # keys, it is a T5X checkpoint. Checkpoints that were created a long time
    # ago may not contain these keys, so there is a backup ts.Spec check as
    # well.
    fp.seek(0)
    key_tree = _read_msgpack_keys(fp)
    if set(key_tree.keys()) == {'version', 'optimizer'}:
      checkpoint_type = CheckpointTypes.T5X
      _warn_if_unexpected_type(
          checkpoint_path,
          checkpoint_type,
          expected,
          'Top-level keys in the msgpack file were "version" and "optimizer", '
          'thus the checkpoint was assumed to be in the T5X format.',
      )
      return checkpoint_type
    elif _contains_ts_spec(key_tree):
      # If the checkpoint contains a ts.Spec, it could either be a T5X
      # checkpoint or an early-version Flax checkpoint. The latter is
      # essentially deprecated but should also be handled by the T5X
      # Checkpointer, so we return T5X here for simplicity.
      checkpoint_type = CheckpointTypes.T5X
      _warn_if_unexpected_type(
          checkpoint_path,
          checkpoint_type,
          expected,
          'Found a ts.Spec in the checkpoint msgpack file, thus the checkpoint'
          ' was assumed to be in the T5X format.',
      )
      return checkpoint_type
    else:
      checkpoint_type = CheckpointTypes.ORBAX
      _warn_if_unexpected_type(
          checkpoint_path,
          checkpoint_type,
          expected,
          'Did not detect a ts.Spec or the {"version", "optimizer"} keys in '
          'the checkpoint msgpack file, so the checkpoint was assumed to be '
          'written with Orbax.',
      )
      return checkpoint_type
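To make the key-based heuristic above concrete, here is a small self-contained sketch (assuming `t5x` is installed at a revision that includes this commit; the temp-dir layout is made up for illustration). It writes a minimal msgpack file whose only top-level keys are `version` and `optimizer`, the primary T5X signature, and checks how it is classified:

import os
import tempfile

import msgpack
from t5x import checkpoint_utils

with tempfile.TemporaryDirectory() as tmp:
  # A toy `.checkpoint` file with exactly {'version', 'optimizer'} at the top.
  ckpt_path = os.path.join(tmp, 'checkpoint')
  with open(ckpt_path, 'wb') as f:
    f.write(msgpack.packb({'version': 3, 'optimizer': {'target': {}}}))

  # There is no `_METADATA` file next to it, and the file does not start with
  # b'model_checkpoint_path', so the msgpack key check decides the type.
  ckpt_type = checkpoint_utils.detect_checkpoint_type(ckpt_path)
  assert ckpt_type is checkpoint_utils.CheckpointTypes.T5X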
41 changes: 41 additions & 0 deletions t5x/checkpoint_utils_test.py
@@ -145,5 +145,46 @@ def test_remove_dataset_checkpoint_pinned(self):
    self.assertTrue(gfile.exists(self.train_ds_file))
    self.assertTrue(gfile.exists(self.ckpt_dir_path))

  def test_detect_checkpoint_type(self):
    tf_ckpt = os.path.join(TESTDATA, "mtf_tiny_t5", "checkpoint")
    orbax_ckpt = os.path.join(TESTDATA, "tiny_orbax", "1", "checkpoint")
    t5_ckpt = os.path.join(TESTDATA, "tiny_t5", "checkpoint_1", "checkpoint")

    ret = checkpoint_utils.detect_checkpoint_type(
        t5_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X
    )
    self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X)

    ret = checkpoint_utils.detect_checkpoint_type(
        tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF
    )
    self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X_TF)

    ret = checkpoint_utils.detect_checkpoint_type(
        orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.ORBAX
    )
    self.assertEqual(ret, checkpoint_utils.CheckpointTypes.ORBAX)

    with self.assertLogs(level="WARN") as log_output:
      checkpoint_utils.detect_checkpoint_type(
          tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X
      )
    self.assertRegex(
        log_output[0][0].message,
        ".*to be CheckpointTypes.T5X format, but the actual detected format was"
        " CheckpointTypes.T5X_TF.*",
    )

    with self.assertLogs(level="WARN") as log_output:
      checkpoint_utils.detect_checkpoint_type(
          orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF
      )
    self.assertRegex(
        log_output[0][0].message,
        ".*to be CheckpointTypes.T5X_TF format, but the actual detected format"
        " was CheckpointTypes.ORBAX.*",
    )


if __name__ == "__main__":
  absltest.main()
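The tests above also illustrate that a mismatched `expected` type is only logged, never raised. For the Orbax branch of the heuristic, the decisive signal is a `_METADATA` file sitting next to the `.checkpoint` file; a minimal sketch under the same assumption that this revision of `t5x` is installed (paths are throwaway temp files):

import os
import tempfile

import msgpack
from t5x import checkpoint_utils

with tempfile.TemporaryDirectory() as tmp:
  ckpt_path = os.path.join(tmp, 'checkpoint')
  with open(ckpt_path, 'wb') as f:
    f.write(msgpack.packb({'optimizer': {}}))
  # The `_METADATA` marker is checked first, so its presence alone makes the
  # checkpoint count as Orbax regardless of the msgpack contents.
  open(os.path.join(tmp, '_METADATA'), 'w').close()

  ckpt_type = checkpoint_utils.detect_checkpoint_type(ckpt_path)
  assert ckpt_type is checkpoint_utils.CheckpointTypes.ORBAX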
96 changes: 69 additions & 27 deletions t5x/checkpoints.py
@@ -228,10 +228,10 @@ def _sync_global_devices(name: str) -> None:


def get_checkpoint_dir(
    checkpoints_dir: str,
    checkpoints_dir: epath.PathLike,
    step: int,
    step_format_fixed_length: Optional[int] = None,
) -> str:
) -> epath.PathLike:
  """Returns path to a checkpoint dir given a parent directory and step."""
  step_str = (
      f'{step:0{step_format_fixed_length}d}'
@@ -366,7 +366,7 @@ def _get_spec(directory: str, arr: Any, name: str,
              metadata: Dict[str, Any]) -> ts.Spec:
  """Get ts.Spec from array and name information."""

  if directory.startswith('gs://'):
  if os.fspath(directory).startswith('gs://'):
    spec = {
        'driver': 'zarr',
        'dtype': jnp.dtype(arr.dtype).name,
@@ -555,18 +555,20 @@ class Checkpointer(object):
    oldest ones will be automatically deleted to save space.
  """

  def __init__(self,
               train_state: train_state_lib.TrainState,
               partitioner: partitioning.BasePartitioner,
               checkpoints_dir: str,
               dataset_iterator: Optional[
                   Union[tf.data.Iterator,
                         clu.data.dataset_iterator.DatasetIterator]] = None,
               *,
               keep: Optional[int] = None,
               save_dtype: jnp.dtype = np.float32,
               restore_dtype: Optional[jnp.dtype] = None,
               keep_dataset_checkpoints: Optional[int] = None):
  def __init__(
      self,
      train_state: train_state_lib.TrainState,
      partitioner: partitioning.BasePartitioner,
      checkpoints_dir: epath.PathLike,
      dataset_iterator: Optional[
          Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]
      ] = None,
      *,
      keep: Optional[int] = None,
      save_dtype: jnp.dtype = np.float32,
      restore_dtype: Optional[jnp.dtype] = None,
      keep_dataset_checkpoints: Optional[int] = None,
  ):
    """Checkpointer constructor.

    Args:
@@ -708,7 +710,7 @@ def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec):
        self._get_state_dict_for_save(self._train_state.state_dict()),
        self._partitioner.get_mesh_axes(self._train_state).state_dict())

  def _get_checkpoint_dir(self, step: int) -> str:
  def _get_checkpoint_dir(self, step: int) -> epath.PathLike:
    return get_checkpoint_dir(self.checkpoints_dir, step)

  def all_steps(self) -> Sequence[int]:
@@ -1035,25 +1037,30 @@ def restore(
    if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path):
      raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}')

    ckpt_type = checkpoint_utils.detect_checkpoint_type(
        ckpt_path, expected=checkpoint_utils.CheckpointTypes.T5X
    )
    if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF:
      raise ValueError(
          'Attempting to restore a TensorFlow checkpoint as a native T5X '
          'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: '
          f'{ckpt_path}'
      )
    # Don't error out here for an Orbax-detected checkpoint (there are edge
    # cases where all values are stored in the msgpack and the checkpoint file
    # can be loaded by both the Orbax and T5X checkpointers).
    logging.info('Restoring from checkpoint: %s', ckpt_path)

    with gfile.GFile(ckpt_path, 'rb') as fp:
      # TODO(adarob): Use threaded reading as in flax.checkpoints.
      raw_contents = fp.read()
      if raw_contents.startswith(b'model_checkpoint_path'):
        raise ValueError(
            'Attempting to restore a TensorFlow checkpoint as a native T5X '
            'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' +
            ckpt_path)

      # `ckpt_contents['optimizer']` is a pytree with a realized np.array for
      # leaves (params or states) written as msgpack and a ts.Spec (in a dict)
      # for leaves written by TensorStore.
      ckpt_contents = serialization.msgpack_restore(raw_contents)
      ckpt_contents = serialization.msgpack_restore(fp.read())

    # If reading a ckpt that was written with gfile driver but the current
    # session uses the gcs driver, convert the ckpt's driver to gcs.
    if ckpt_dir.startswith('gs://'):
    if os.fspath(ckpt_dir).startswith('gs://'):
      ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
    # If a ckpt was saved in gcs and is being loaded locally, then convert the
    # driver to file or gfile. If the ckpt was not saved in gcs, do not change.
@@ -1716,7 +1723,9 @@ def fake_param_info(maybe_tspec: Any) -> Optional[_ParameterInfo]:
      axes=None)


def find_checkpoint(path: str, step: Optional[int] = None) -> str:
def find_checkpoint(
    path: epath.PathLike, step: Optional[int] = None
) -> epath.PathLike:
  """Find the checkpoint file based on paths and steps.

  Args:

# If reading a ckpt that was written with gfile driver but the current
# session uses the gcs driver, convert the ckpt's driver to gcs.
if path.startswith('gs://'):
if os.fspath(path).startswith('gs://'):
ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
# If a ckpt was saved in gcs and is being loaded locally, then convert the
# driver to file or gfile. If the ckpt was not saved in gcs, do not change.
Expand Down Expand Up @@ -2195,6 +2204,9 @@ def best_fn(metrics):
step_prefix='checkpoint',
)
options.metric_name_to_monitor = metric_name_to_monitor

if not gfile.isdir(directory):
directory = os.path.dirname(directory)
self._manager = self._CheckpointManagerImpl(
directory=directory, checkpointers=checkpointers, options=options
)
Expand Down Expand Up @@ -2315,6 +2327,36 @@ def restore(
if path is not None:
directory, step = get_step_from_checkpoint_dir(os.fspath(path))

# Check for legacy T5X checkpoint: If so, use legacy T5X
# checkpointer to restore the state. The following exclusive features of T5X
# checkpoint are skipped: DatasetIterator, [add more here when discovered]
try:
ckpt_path = find_checkpoint(directory, step)
except ValueError:
# `find_checkpoint` fails if the `.checkpoint` file isn't directly in
# the checkpoint directory. In this case, leave path as None and skip
# the legacy T5X checkpoint check.
ckpt_path = None

if ckpt_path is not None:
ckpt_type = checkpoint_utils.detect_checkpoint_type(
ckpt_path, expected=checkpoint_utils.CheckpointTypes.ORBAX
)
if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF:
raise ValueError(
'Attempting to restore a TensorFlow checkpoint as a native T5X '
'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: '
+ ckpt_path
)
elif ckpt_type is checkpoint_utils.CheckpointTypes.T5X:
legacy_checkpointer = Checkpointer(
self._train_state,
self._partitioner,
self.directory,
restore_dtype=self._restore_dtype,
)
return legacy_checkpointer.restore(path=path)

state_dict = self._train_state.state_dict()
# Returns a state dict rather than a train state.
param_infos = _construct_orbax_param_infos(
Binary file added t5x/testdata/tiny_orbax/1/_optimizer.state.step/0
1 change: 1 addition & 0 deletions t5x/testdata/tiny_orbax/1/checkpoint
@@ -0,0 +1 @@
��_optimizer��state��param_states��bias�0PLACEHOLDER://_optimizer.state.param_states.bias�kernel�2PLACEHOLDER://_optimizer.state.param_states.kernel�step�0PLACEHOLDER://_optimizer.state.param_states.step�step�#PLACEHOLDER://_optimizer.state.step�target��flax_mutables��flax_mutables_axes��params_axes�
Binary file added t5x/testdata/tiny_t5/checkpoint_1/checkpoint
