diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py
index 7fe256461..60e517dbf 100644
--- a/t5x/checkpoint_utils.py
+++ b/t5x/checkpoint_utils.py
@@ -18,16 +18,21 @@
 removal process.
 """
 
+import enum
 import os
+from typing import Any, BinaryIO, Optional
 
 from absl import logging
-
+from etils import epath
+import msgpack
 from tensorflow.io import gfile
 
 # PINNED file in the checkpoint directory indicates that the checkpoint should
 # not be removed during the automatic pruning of old checkpoints.
 _PINNED_CHECKPOINT_FILENAME = 'PINNED'
 
+PyTree = Any
+
 
 def pinned_checkpoint_filepath(ckpt_dir: str) -> str:
   """Full path of the pinned checkpoint file."""
@@ -89,3 +94,155 @@ def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None:
       gfile.remove(file)
   else:
     logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
+
+
+def _read_msgpack_keys(file_like: BinaryIO) -> PyTree:
+  """Returns a tree containing all keys but no values from a msgpack file."""
+  unpacker = msgpack.Unpacker(file_like)
+  num_keys = unpacker.read_map_header()
+  ret = {}
+
+  # Contains references to the parent tree for each key to visit in the
+  # msgpack file traversal.
+  visit_stack = [ret for _ in range(num_keys)]
+  while visit_stack:
+    parent_dict = visit_stack.pop()
+    key = unpacker.unpack()
+    if isinstance(key, bytes):
+      key = str(key, 'utf-8')
+
+    # Check if the value object is a map.
+    try:
+      n = unpacker.read_map_header()
+      ref = parent_dict[key] = {}
+      visit_stack.extend(ref for _ in range(n))
+    except msgpack.UnpackValueError:
+      # Not a map, so skip unpacking the value object and record the current
+      # key.
+      unpacker.skip()
+      parent_dict[key] = None
+
+  return ret
+
+
+def _contains_ts_spec(tree: PyTree) -> bool:
+  """Returns whether the PyTree contains a serialized ts.Spec object."""
+  to_visit = [tree]
+  while to_visit:
+    cur = to_visit.pop()
+    if cur.keys() >= {'driver', 'kvstore', 'metadata'}:
+      return True
+    to_visit.extend(v for v in cur.values() if isinstance(v, dict))
+  return False
+
+
+# Constant copied from orbax/checkpoint/pytree_checkpoint_handler.py
+_METADATA_FILE = '_METADATA'
+
+
+def _contains_orbax_metadata(ckpt_path: str) -> bool:
+  metadata = os.path.join(os.path.dirname(ckpt_path), _METADATA_FILE)
+  return gfile.exists(metadata)
+
+
+class CheckpointTypes(enum.Enum):
+  ORBAX = 'ORBAX'
+  T5X = 'T5X'
+  T5X_TF = 'T5X_TF'
+
+
+def _warn_if_unexpected_type(
+    checkpoint_path, checkpoint_type, expected, extra_warn_log
+):
+  """Warns the user if an unexpected checkpoint type was found."""
+  if expected is None or checkpoint_type == expected:
+    return
+
+  logging.warning(
+      'Expected the checkpoint at %s to be %s format, but'
+      ' the actual detected format was %s.',
+      checkpoint_path,
+      expected,
+      checkpoint_type,
+  )
+  logging.warning(extra_warn_log)
+
+
+def detect_checkpoint_type(
+    checkpoint_path: epath.PathLike, expected: Optional[CheckpointTypes] = None
+) -> CheckpointTypes:
+  """Returns the checkpoint type by reading the `.checkpoint` metadata file.
+
+  Args:
+    checkpoint_path: The path of the `.checkpoint` file.
+    expected: The expected checkpoint type. If the checkpoint type is not as
+      expected, this function will log a warning but will not raise an error.
+
+  Returns:
+    The checkpoint type.
+ """ + if _contains_orbax_metadata(checkpoint_path): + checkpoint_type = CheckpointTypes.ORBAX + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + f'Found `{_METADATA_FILE}` in the checkpoint directory, which only ' + 'appears in Orbax checkpoints', + ) + return checkpoint_type + + with gfile.GFile(checkpoint_path, 'rb') as fp: + raw_contents = fp.read(21) + if raw_contents == b'model_checkpoint_path': + checkpoint_type = CheckpointTypes.T5X_TF + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'The checkpoint file was not a msgpack, and had the string ' + '"model_checkpoint_path", so it was assumed to be in the T5X ' + 'TensorFlow format.', + ) + return checkpoint_type + + # Assume that if the msgpack file has exactly 'version' and 'optimizer' as + # keys, it is a T5X checkpoint. Checkpoints that were created a long time + # ago may not contain these keys, so there is a backup ts.Spec check + # as well. + fp.seek(0) + key_tree = _read_msgpack_keys(fp) + if set(key_tree.keys()) == {'version', 'optimizer'}: + checkpoint_type = CheckpointTypes.T5X + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'Top-level keys in the msgpack file were "version" and "optimizer", ' + 'thus the checkpoint was assumed to be in the T5X format.', + ) + return checkpoint_type + elif _contains_ts_spec(key_tree): + # If the checkpoint contains a ts.Spec, it could either be a T5X + # checkpoint or an early version Flax checkpoint. The latter is + # essentially deprecated but should also be handled by the T5X + # Checkpointer, so we return T5X here for simplicity. + checkpoint_type = CheckpointTypes.T5X + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'Found ts.Spec in the checkpoint msgpack file, thus the checkpoint' + ' was assumed to be in the T5X format.', + ) + return checkpoint_type + else: + checkpoint_type = CheckpointTypes.ORBAX + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'Did not detect ts.Spec nor the {"version", "optimizer"} keys in the' + 'checkpoint msgpack file, so the checkpoint was assumed to be ' + 'written with Orbax.', + ) + return checkpoint_type diff --git a/t5x/checkpoint_utils_test.py b/t5x/checkpoint_utils_test.py index fe2deaf72..ef54689f3 100644 --- a/t5x/checkpoint_utils_test.py +++ b/t5x/checkpoint_utils_test.py @@ -145,5 +145,46 @@ def test_remove_dataset_checkpoint_pinned(self): self.assertTrue(gfile.exists(self.train_ds_file)) self.assertTrue(gfile.exists(self.ckpt_dir_path)) + def test_detect_checkpoint_type(self): + tf_ckpt = os.path.join(TESTDATA, "mtf_tiny_t5", "checkpoint") + orbax_ckpt = os.path.join(TESTDATA, "tiny_orbax", "1", "checkpoint") + t5_ckpt = os.path.join(TESTDATA, "tiny_t5", "checkpoint_1", "checkpoint") + + ret = checkpoint_utils.detect_checkpoint_type( + t5_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X + ) + self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X) + + ret = checkpoint_utils.detect_checkpoint_type( + tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF + ) + self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X_TF) + + ret = checkpoint_utils.detect_checkpoint_type( + orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.ORBAX + ) + self.assertEqual(ret, checkpoint_utils.CheckpointTypes.ORBAX) + + with self.assertLogs(level="WARN") as log_output: + checkpoint_utils.detect_checkpoint_type( + tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X + ) + self.assertRegex( + 
+        log_output[0][0].message,
+        ".*to be CheckpointTypes.T5X format, but the actual detected format was"
+        " CheckpointTypes.T5X_TF.*",
+    )
+
+    with self.assertLogs(level="WARN") as log_output:
+      checkpoint_utils.detect_checkpoint_type(
+          orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF
+      )
+    self.assertRegex(
+        log_output[0][0].message,
+        ".*to be CheckpointTypes.T5X_TF format, but the actual detected format"
+        " was CheckpointTypes.ORBAX.*",
+    )
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py
index 8c4f825af..1842a8558 100644
--- a/t5x/checkpoints.py
+++ b/t5x/checkpoints.py
@@ -228,10 +228,10 @@ def _sync_global_devices(name: str) -> None:
 
 
 def get_checkpoint_dir(
-    checkpoints_dir: str,
+    checkpoints_dir: epath.PathLike,
     step: int,
     step_format_fixed_length: Optional[int] = None,
-) -> str:
+) -> epath.PathLike:
   """Returns path to a checkpoint dir given a parent directory and step."""
   step_str = (
       f'{step:0{step_format_fixed_length}d}'
@@ -366,7 +366,7 @@
 def _get_spec(directory: str, arr: Any, name: str,
               metadata: Dict[str, Any]) -> ts.Spec:
   """Get ts.Spec from array and name information."""
-  if directory.startswith('gs://'):
+  if os.fspath(directory).startswith('gs://'):
     spec = {
         'driver': 'zarr',
         'dtype': jnp.dtype(arr.dtype).name,
@@ -555,18 +555,20 @@ class Checkpointer(object):
   oldest ones will be automatically deleted to save space.
   """
 
-  def __init__(self,
-               train_state: train_state_lib.TrainState,
-               partitioner: partitioning.BasePartitioner,
-               checkpoints_dir: str,
-               dataset_iterator: Optional[
-                   Union[tf.data.Iterator,
-                         clu.data.dataset_iterator.DatasetIterator]] = None,
-               *,
-               keep: Optional[int] = None,
-               save_dtype: jnp.dtype = np.float32,
-               restore_dtype: Optional[jnp.dtype] = None,
-               keep_dataset_checkpoints: Optional[int] = None):
+  def __init__(
+      self,
+      train_state: train_state_lib.TrainState,
+      partitioner: partitioning.BasePartitioner,
+      checkpoints_dir: epath.PathLike,
+      dataset_iterator: Optional[
+          Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]
+      ] = None,
+      *,
+      keep: Optional[int] = None,
+      save_dtype: jnp.dtype = np.float32,
+      restore_dtype: Optional[jnp.dtype] = None,
+      keep_dataset_checkpoints: Optional[int] = None,
+  ):
     """Checkpointer constructor.
 
     Args:
@@ -708,7 +710,7 @@ def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec):
         self._get_state_dict_for_save(self._train_state.state_dict()),
         self._partitioner.get_mesh_axes(self._train_state).state_dict())
 
-  def _get_checkpoint_dir(self, step: int) -> str:
+  def _get_checkpoint_dir(self, step: int) -> epath.PathLike:
     return get_checkpoint_dir(self.checkpoints_dir, step)
 
   def all_steps(self) -> Sequence[int]:
@@ -1035,25 +1037,30 @@ def restore(
     if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path):
       raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}')
 
+    ckpt_type = checkpoint_utils.detect_checkpoint_type(
+        ckpt_path, expected=checkpoint_utils.CheckpointTypes.T5X
+    )
+    if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF:
+      raise ValueError(
+          'Attempting to restore a TensorFlow checkpoint as a native T5X '
+          'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: '
+          f'{ckpt_path}'
+      )
+    # Don't error out here for an Orbax-detected checkpoint (there are edge
+    # cases where all values are stored in the msgpack and the checkpoint
+    # file can be loaded by both the Orbax and T5X checkpointer).
 
     logging.info('Restoring from checkpoint: %s', ckpt_path)
 
     with gfile.GFile(ckpt_path, 'rb') as fp:
       # TODO(adarob): Use threaded reading as in flax.checkpoints.
-      raw_contents = fp.read()
-      if raw_contents.startswith(b'model_checkpoint_path'):
-        raise ValueError(
-            'Attempting to restore a TensorFlow checkpoint as a native T5X '
-            'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' +
-            ckpt_path)
-
       # `ckpt_contents['optimizer']` is a pytree with a realized np.array for
       # leaves (params or states) written as msgpack and a ts.Spec (in a dict)
       # for leaves written by TensorStore.
-      ckpt_contents = serialization.msgpack_restore(raw_contents)
+      ckpt_contents = serialization.msgpack_restore(fp.read())
 
       # If reading a ckpt that was written with gfile driver but the current
       # session uses the gcs driver, convert the ckpt's driver to gcs.
-      if ckpt_dir.startswith('gs://'):
+      if os.fspath(ckpt_dir).startswith('gs://'):
         ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
       # If a ckpt was saved in gcs and is being loaded locally, then convert the
       # driver to file or gfile. If the ckpt was not saved in gcs, do not change.
@@ -1716,7 +1723,9 @@ def fake_param_info(maybe_tspec: Any) -> Optional[_ParameterInfo]:
       axes=None)
 
 
-def find_checkpoint(path: str, step: Optional[int] = None) -> str:
+def find_checkpoint(
+    path: epath.PathLike, step: Optional[int] = None
+) -> epath.PathLike:
   """Find the checkpoint file based on paths and steps.
 
   Args:
@@ -1791,7 +1800,7 @@ def load_t5x_checkpoint(
 
   # If reading a ckpt that was written with gfile driver but the current
   # session uses the gcs driver, convert the ckpt's driver to gcs.
-  if path.startswith('gs://'):
+  if os.fspath(path).startswith('gs://'):
     ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
   # If a ckpt was saved in gcs and is being loaded locally, then convert the
   # driver to file or gfile. If the ckpt was not saved in gcs, do not change.
@@ -2195,6 +2204,9 @@ def best_fn(metrics):
         step_prefix='checkpoint',
     )
     options.metric_name_to_monitor = metric_name_to_monitor
+
+    if not gfile.isdir(directory):
+      directory = os.path.dirname(directory)
     self._manager = self._CheckpointManagerImpl(
         directory=directory, checkpointers=checkpointers, options=options
     )
@@ -2315,6 +2327,36 @@ def restore(
     if path is not None:
       directory, step = get_step_from_checkpoint_dir(os.fspath(path))
 
+    # Check for a legacy T5X checkpoint: if so, use the legacy T5X
+    # checkpointer to restore the state. The following exclusive features of
+    # T5X checkpoints are skipped: DatasetIterator, [add more here when
+    # discovered]
+    try:
+      ckpt_path = find_checkpoint(directory, step)
+    except ValueError:
+      # `find_checkpoint` fails if the `.checkpoint` file isn't directly in
+      # the checkpoint directory. In this case, leave the path as None and
+      # skip the legacy T5X checkpoint check.
+      ckpt_path = None
+
+    if ckpt_path is not None:
+      ckpt_type = checkpoint_utils.detect_checkpoint_type(
+          ckpt_path, expected=checkpoint_utils.CheckpointTypes.ORBAX
+      )
+      if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF:
+        raise ValueError(
+            'Attempting to restore a TensorFlow checkpoint as a native T5X '
Path: ' + + ckpt_path + ) + elif ckpt_type is checkpoint_utils.CheckpointTypes.T5X: + legacy_checkpointer = Checkpointer( + self._train_state, + self._partitioner, + self.directory, + restore_dtype=self._restore_dtype, + ) + return legacy_checkpointer.restore(path=path) + state_dict = self._train_state.state_dict() # Returns a state dict rather than a train state. param_infos = _construct_orbax_param_infos( diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 new file mode 100644 index 000000000..a456e60f9 Binary files /dev/null and b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 differ diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 new file mode 100644 index 000000000..207b68275 Binary files /dev/null and b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 differ diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 new file mode 100644 index 000000000..437babcba Binary files /dev/null and b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 differ diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 new file mode 100644 index 000000000..437babcba Binary files /dev/null and b/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 differ diff --git a/t5x/testdata/tiny_orbax/1/checkpoint b/t5x/testdata/tiny_orbax/1/checkpoint new file mode 100644 index 000000000..479f5c61b --- /dev/null +++ b/t5x/testdata/tiny_orbax/1/checkpoint @@ -0,0 +1 @@ +„ª_optimizer‚¥state‚¬param_statesƒ¤biasÙ0PLACEHOLDER://_optimizer.state.param_states.bias¦kernelÙ2PLACEHOLDER://_optimizer.state.param_states.kernel¤stepÙ0PLACEHOLDER://_optimizer.state.param_states.step¤stepÙ#PLACEHOLDER://_optimizer.state.step¦target€­flax_mutablesÀ²flax_mutables_axesÀ«params_axesÀ \ No newline at end of file diff --git a/t5x/testdata/tiny_t5/checkpoint_1/checkpoint b/t5x/testdata/tiny_t5/checkpoint_1/checkpoint new file mode 100644 index 000000000..5ddf7bb96 Binary files /dev/null and b/t5x/testdata/tiny_t5/checkpoint_1/checkpoint differ diff --git a/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 new file mode 100644 index 000000000..9f6100a0c Binary files /dev/null and b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 differ diff --git a/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 new file mode 100644 index 000000000..c30283587 Binary files /dev/null and b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 differ diff --git a/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 new file mode 100644 index 000000000..872eef28e Binary files /dev/null and b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 differ