From 22036d7851f78aac72bdad714992a55752902e02 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Sun, 23 Jul 2023 03:17:04 -0700 Subject: [PATCH] Add branch to detect T5 checkpoints in OrbaxCheckpointManager which checks if the checkpoint is T5. In the case that it is, the legacy checkpointer is used to load the checkpoint. PiperOrigin-RevId: 550325349 --- t5x/checkpoint_utils.py | 159 +++++++++++++++++- t5x/checkpoint_utils_test.py | 41 +++++ t5x/checkpoints.py | 96 ++++++++--- .../1/_optimizer.state.param_states.bias/0 | Bin 0 -> 19 bytes .../_optimizer.state.param_states.kernel/0.0 | Bin 0 -> 409 bytes .../1/_optimizer.state.param_states.step/0 | Bin 0 -> 13 bytes .../tiny_orbax/1/_optimizer.state.step/0 | Bin 0 -> 13 bytes t5x/testdata/tiny_orbax/1/checkpoint | 1 + t5x/testdata/tiny_t5/checkpoint_1/checkpoint | Bin 0 -> 488 bytes .../checkpoint_1/state.param_states.bias/0 | Bin 0 -> 26 bytes .../state.param_states.kernel/0.0 | Bin 0 -> 423 bytes .../checkpoint_1/state.param_states.step/0 | Bin 0 -> 24 bytes 12 files changed, 269 insertions(+), 28 deletions(-) create mode 100644 t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 create mode 100644 t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 create mode 100644 t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 create mode 100644 t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 create mode 100644 t5x/testdata/tiny_orbax/1/checkpoint create mode 100644 t5x/testdata/tiny_t5/checkpoint_1/checkpoint create mode 100644 t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 create mode 100644 t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 create mode 100644 t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py index 7fe256461..60e517dbf 100644 --- a/t5x/checkpoint_utils.py +++ b/t5x/checkpoint_utils.py @@ -18,16 +18,21 @@ removal process. """ +import enum import os +from typing import Any, BinaryIO, Optional from absl import logging - +from etils import epath +import msgpack from tensorflow.io import gfile # PINNED file in the checkpoint directory indicates that the checkpoint should # not be removed during the automatic pruning of old checkpoints. _PINNED_CHECKPOINT_FILENAME = 'PINNED' +PyTree = Any + def pinned_checkpoint_filepath(ckpt_dir: str) -> str: """Full path of the pinned checkpoint file.""" @@ -89,3 +94,155 @@ def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None: gfile.remove(file) else: logging.info('Keeping pinned checkpoint: %s', ckpt_dir) + + +def _read_msgpack_keys(file_like: BinaryIO) -> PyTree: + """Returns a tree containing all keys but no values from a msgpack file.""" + unpacker = msgpack.Unpacker(file_like) + num_keys = unpacker.read_map_header() + ret = {} + + # Contains references to the parent tree for each key to visit in the + # msgpack file traversal. + visit_stack = [ret for _ in range(num_keys)] + while visit_stack: + parent_dict = visit_stack.pop() + key = unpacker.unpack() + if isinstance(key, bytes): + key = str(unpacker.unpack(), 'utf-8') + + # Check if the value object is map. + try: + n = unpacker.read_map_header() + ref = parent_dict[key] = {} + visit_stack.extend(ref for _ in range(n)) + except msgpack.UnpackValueError: + # Not a map so skip unpacking the value object and record the current key. + unpacker.skip() + parent_dict[key] = None + + return ret + + +def _contains_ts_spec(tree: PyTree) -> bool: + """Returns whether the a Pytree contains a serialized ts.Spec object.""" + to_visit = [tree] + while to_visit: + cur = to_visit.pop() + if cur.keys() >= {'driver', 'kvstore', 'metadata'}: + return True + to_visit.extend(v for v in cur.values() if isinstance(v, dict)) + return False + + +# Constant copied from orbax/checkpoint/pytree_checkpoint_handler.py +_METADATA_FILE = '_METADATA' + + +def _contains_orbax_metadata(ckpt_path: str) -> bool: + metadata = os.path.join(os.path.dirname(ckpt_path), _METADATA_FILE) + return gfile.exists(metadata) + + +class CheckpointTypes(enum.Enum): + ORBAX = 'ORBAX' + T5X = 'T5X' + T5X_TF = 'T5X_TF' + + +def _warn_if_unexpected_type( + checkpoint_path, checkpoint_type, expected, extra_warn_log +): + """Warns the user if unexpected type found.""" + if expected is None or checkpoint_type == expected: + return + + logging.warning( + 'Expected the checkpoint at %s to be %s format, but' + ' the actual detected format was %s.', + checkpoint_path, + expected, + checkpoint_type, + ) + logging.warning(extra_warn_log) + + +def detect_checkpoint_type( + checkpoint_path: epath.PathLike, expected: Optional[CheckpointTypes] = None +) -> CheckpointTypes: + """Returns the checkpoint type by reading the `.checkpoint` metadata file. + + Args: + checkpoint_path: The path of the `.checkpoint` file. + expected: The expected checkpoint type. If the checkpoint type is not as + expected, this function will log a warning but will not raise an error. + + Returns: + The checkpoint type. + """ + if _contains_orbax_metadata(checkpoint_path): + checkpoint_type = CheckpointTypes.ORBAX + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + f'Found `{_METADATA_FILE}` in the checkpoint directory, which only ' + 'appears in Orbax checkpoints', + ) + return checkpoint_type + + with gfile.GFile(checkpoint_path, 'rb') as fp: + raw_contents = fp.read(21) + if raw_contents == b'model_checkpoint_path': + checkpoint_type = CheckpointTypes.T5X_TF + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'The checkpoint file was not a msgpack, and had the string ' + '"model_checkpoint_path", so it was assumed to be in the T5X ' + 'TensorFlow format.', + ) + return checkpoint_type + + # Assume that if the msgpack file has exactly 'version' and 'optimizer' as + # keys, it is a T5X checkpoint. Checkpoints that were created a long time + # ago may not contain these keys, so there is a backup ts.Spec check + # as well. + fp.seek(0) + key_tree = _read_msgpack_keys(fp) + if set(key_tree.keys()) == {'version', 'optimizer'}: + checkpoint_type = CheckpointTypes.T5X + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'Top-level keys in the msgpack file were "version" and "optimizer", ' + 'thus the checkpoint was assumed to be in the T5X format.', + ) + return checkpoint_type + elif _contains_ts_spec(key_tree): + # If the checkpoint contains a ts.Spec, it could either be a T5X + # checkpoint or an early version Flax checkpoint. The latter is + # essentially deprecated but should also be handled by the T5X + # Checkpointer, so we return T5X here for simplicity. + checkpoint_type = CheckpointTypes.T5X + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'Found ts.Spec in the checkpoint msgpack file, thus the checkpoint' + ' was assumed to be in the T5X format.', + ) + return checkpoint_type + else: + checkpoint_type = CheckpointTypes.ORBAX + _warn_if_unexpected_type( + checkpoint_path, + checkpoint_type, + expected, + 'Did not detect ts.Spec nor the {"version", "optimizer"} keys in the' + 'checkpoint msgpack file, so the checkpoint was assumed to be ' + 'written with Orbax.', + ) + return checkpoint_type diff --git a/t5x/checkpoint_utils_test.py b/t5x/checkpoint_utils_test.py index fe2deaf72..ef54689f3 100644 --- a/t5x/checkpoint_utils_test.py +++ b/t5x/checkpoint_utils_test.py @@ -145,5 +145,46 @@ def test_remove_dataset_checkpoint_pinned(self): self.assertTrue(gfile.exists(self.train_ds_file)) self.assertTrue(gfile.exists(self.ckpt_dir_path)) + def test_detect_checkpoint_type(self): + tf_ckpt = os.path.join(TESTDATA, "mtf_tiny_t5", "checkpoint") + orbax_ckpt = os.path.join(TESTDATA, "tiny_orbax", "1", "checkpoint") + t5_ckpt = os.path.join(TESTDATA, "tiny_t5", "checkpoint_1", "checkpoint") + + ret = checkpoint_utils.detect_checkpoint_type( + t5_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X + ) + self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X) + + ret = checkpoint_utils.detect_checkpoint_type( + tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF + ) + self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X_TF) + + ret = checkpoint_utils.detect_checkpoint_type( + orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.ORBAX + ) + self.assertEqual(ret, checkpoint_utils.CheckpointTypes.ORBAX) + + with self.assertLogs(level="WARN") as log_output: + checkpoint_utils.detect_checkpoint_type( + tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X + ) + self.assertRegex( + log_output[0][0].message, + ".*to be CheckpointTypes.T5X format, but the actual detected format was" + " CheckpointTypes.T5X_TF.*", + ) + + with self.assertLogs(level="WARN") as log_output: + checkpoint_utils.detect_checkpoint_type( + orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF + ) + self.assertRegex( + log_output[0][0].message, + ".*to be CheckpointTypes.T5X_TF format, but the actual detected format" + " was CheckpointTypes.ORBAX.*", + ) + + if __name__ == "__main__": absltest.main() diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 8c4f825af..1842a8558 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -228,10 +228,10 @@ def _sync_global_devices(name: str) -> None: def get_checkpoint_dir( - checkpoints_dir: str, + checkpoints_dir: epath.PathLike, step: int, step_format_fixed_length: Optional[int] = None, -) -> str: +) -> epath.PathLike: """Returns path to a checkpoint dir given a parent directory and step.""" step_str = ( f'{step:0{step_format_fixed_length}d}' @@ -366,7 +366,7 @@ def _get_spec(directory: str, arr: Any, name: str, metadata: Dict[str, Any]) -> ts.Spec: """Get ts.Spec from array and name information.""" - if directory.startswith('gs://'): + if os.fspath(directory).startswith('gs://'): spec = { 'driver': 'zarr', 'dtype': jnp.dtype(arr.dtype).name, @@ -555,18 +555,20 @@ class Checkpointer(object): oldest ones will be automatically deleted to save space. """ - def __init__(self, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - checkpoints_dir: str, - dataset_iterator: Optional[ - Union[tf.data.Iterator, - clu.data.dataset_iterator.DatasetIterator]] = None, - *, - keep: Optional[int] = None, - save_dtype: jnp.dtype = np.float32, - restore_dtype: Optional[jnp.dtype] = None, - keep_dataset_checkpoints: Optional[int] = None): + def __init__( + self, + train_state: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, + checkpoints_dir: epath.PathLike, + dataset_iterator: Optional[ + Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] + ] = None, + *, + keep: Optional[int] = None, + save_dtype: jnp.dtype = np.float32, + restore_dtype: Optional[jnp.dtype] = None, + keep_dataset_checkpoints: Optional[int] = None, + ): """Checkpointer constructor. Args: @@ -708,7 +710,7 @@ def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec): self._get_state_dict_for_save(self._train_state.state_dict()), self._partitioner.get_mesh_axes(self._train_state).state_dict()) - def _get_checkpoint_dir(self, step: int) -> str: + def _get_checkpoint_dir(self, step: int) -> epath.PathLike: return get_checkpoint_dir(self.checkpoints_dir, step) def all_steps(self) -> Sequence[int]: @@ -1035,25 +1037,30 @@ def restore( if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path): raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}') + ckpt_type = checkpoint_utils.detect_checkpoint_type( + ckpt_path, expected=checkpoint_utils.CheckpointTypes.T5X + ) + if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF: + raise ValueError( + 'Attempting to restore a TensorFlow checkpoint as a native T5X ' + 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' + f'{ckpt_path}' + ) + # Don't error out here for Orbax-detected checkpoint (there are edge cases + # where all values are stored in the msgpack and the checkpoint file can be + # loaded by both the Orbax and T5X checkpointer). logging.info('Restoring from checkpoint: %s', ckpt_path) with gfile.GFile(ckpt_path, 'rb') as fp: # TODO(adarob): Use threaded reading as in flax.checkpoints. - raw_contents = fp.read() - if raw_contents.startswith(b'model_checkpoint_path'): - raise ValueError( - 'Attempting to restore a TensorFlow checkpoint as a native T5X ' - 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' + - ckpt_path) - # `ckpt_contents['optimizer']` is a pytree with a realized np.array for # leaves (params or states) written as msgpack and a ts.Spec (in a dict) # for leaves written by TensorStore. - ckpt_contents = serialization.msgpack_restore(raw_contents) + ckpt_contents = serialization.msgpack_restore(fp.read()) # If reading a ckpt that was written with gfile driver but the current # session uses the gcs driver, convert the ckpt's driver to gcs. - if ckpt_dir.startswith('gs://'): + if os.fspath(ckpt_dir).startswith('gs://'): ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents) # If a ckpt was saved in gcs and is being loaded locally, then convert the # driver to file or gfile. If the ckpt was not saved in gcs, do not change. @@ -1716,7 +1723,9 @@ def fake_param_info(maybe_tspec: Any) -> Optional[_ParameterInfo]: axes=None) -def find_checkpoint(path: str, step: Optional[int] = None) -> str: +def find_checkpoint( + path: epath.PathLike, step: Optional[int] = None +) -> epath.PathLike: """Find the checkpoint file based on paths and steps. Args: @@ -1791,7 +1800,7 @@ def load_t5x_checkpoint( # If reading a ckpt that was written with gfile driver but the current # session uses the gcs driver, convert the ckpt's driver to gcs. - if path.startswith('gs://'): + if os.fspath(path).startswith('gs://'): ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents) # If a ckpt was saved in gcs and is being loaded locally, then convert the # driver to file or gfile. If the ckpt was not saved in gcs, do not change. @@ -2195,6 +2204,9 @@ def best_fn(metrics): step_prefix='checkpoint', ) options.metric_name_to_monitor = metric_name_to_monitor + + if not gfile.isdir(directory): + directory = os.path.dirname(directory) self._manager = self._CheckpointManagerImpl( directory=directory, checkpointers=checkpointers, options=options ) @@ -2315,6 +2327,36 @@ def restore( if path is not None: directory, step = get_step_from_checkpoint_dir(os.fspath(path)) + # Check for legacy T5X checkpoint: If so, use legacy T5X + # checkpointer to restore the state. The following exclusive features of T5X + # checkpoint are skipped: DatasetIterator, [add more here when discovered] + try: + ckpt_path = find_checkpoint(directory, step) + except ValueError: + # `find_checkpoint` fails if the `.checkpoint` file isn't directly in + # the checkpoint directory. In this case, leave path as None and skip + # the legacy T5X checkpoint check. + ckpt_path = None + + if ckpt_path is not None: + ckpt_type = checkpoint_utils.detect_checkpoint_type( + ckpt_path, expected=checkpoint_utils.CheckpointTypes.ORBAX + ) + if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF: + raise ValueError( + 'Attempting to restore a TensorFlow checkpoint as a native T5X ' + 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' + + ckpt_path + ) + elif ckpt_type is checkpoint_utils.CheckpointTypes.T5X: + legacy_checkpointer = Checkpointer( + self._train_state, + self._partitioner, + self.directory, + restore_dtype=self._restore_dtype, + ) + return legacy_checkpointer.restore(path=path) + state_dict = self._train_state.state_dict() # Returns a state dict rather than a train state. param_infos = _construct_orbax_param_infos( diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 new file mode 100644 index 0000000000000000000000000000000000000000..a456e60f98c70763c791e0171116001df557d368 GIT binary patch literal 19 acmdPcs{faPA(VkZfq|jHo{>Shj{^WM(F6DZ literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 new file mode 100644 index 0000000000000000000000000000000000000000..207b68275d401f7e5180d0c9a34f183769ecc8e8 GIT binary patch literal 409 zcmV;K0cQRvwJ-eu0D%kuIN|0W$v?wC^2}bnMr`IkZ$?us8x!C^e~BNdnWW`EjsaG> z;KSNKYX0|C7p>b?Wq!jHxtPnvRf@j%3@#ShT zEBxR;)Kd3Bz`@Wz@J<}fd$h(sZC3Y2r)T6pZ*|?jXq@9e>b-lg#=q@92-#sxbLqZ6 z3L}{O=QP$o&_ya{fgRC5>X;yV#4FT4ls_a*M{UPHILZ-hP;%=(Fx^hrZkXUdUz|{j zPDtxNB>a9k8yCqx&?7F*S)!>wytL}e2T;jMkXy8iA?H0oEg^WOSIiTP+Y-R z(9Fj_Kv;m9EJ5i%&gcJ><2>#^CLVjvLr~a1oXoITAbHh4Y7c-f*}&*OVak>OJlX9( D>buXP literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 new file mode 100644 index 0000000000000000000000000000000000000000..437babcba202ff9cfb2ae79f844c2cf025332a0a GIT binary patch literal 13 UcmdPcs{faPL6L!hk%55$02gBdEC2ui literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 b/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 new file mode 100644 index 0000000000000000000000000000000000000000..437babcba202ff9cfb2ae79f844c2cf025332a0a GIT binary patch literal 13 UcmdPcs{faPL6L!hk%55$02gBdEC2ui literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_orbax/1/checkpoint b/t5x/testdata/tiny_orbax/1/checkpoint new file mode 100644 index 000000000..479f5c61b --- /dev/null +++ b/t5x/testdata/tiny_orbax/1/checkpoint @@ -0,0 +1 @@ +„ª_optimizer‚¥state‚¬param_statesƒ¤biasÙ0PLACEHOLDER://_optimizer.state.param_states.bias¦kernelÙ2PLACEHOLDER://_optimizer.state.param_states.kernel¤stepÙ0PLACEHOLDER://_optimizer.state.param_states.step¤stepÙ#PLACEHOLDER://_optimizer.state.step¦target€­flax_mutablesÀ²flax_mutables_axesÀ«params_axesÀ \ No newline at end of file diff --git a/t5x/testdata/tiny_t5/checkpoint_1/checkpoint b/t5x/testdata/tiny_t5/checkpoint_1/checkpoint new file mode 100644 index 0000000000000000000000000000000000000000..5ddf7bb965f98df3ef31af740a397d4be1b06a6e GIT binary patch literal 488 zcmb8r!AiqG5P;#7+KUGtrGkpDBV#j7hHSFS%os2y+4=yfzJcr}Dn5XC7JMz+)I*`w z;&Ea2pO5i`At>Ga`HN}rDGJa?4KwNSwk)w ztkI03;1Gt|3pz$Vkw literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 new file mode 100644 index 0000000000000000000000000000000000000000..9f6100a0c356d11fb39b17061ffb8ee19a5184d3 GIT binary patch literal 26 ccmb2|=3oE==H!Ho4e1Vt8De+IMu`B007fPSod5s; literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 new file mode 100644 index 0000000000000000000000000000000000000000..c302835877f986d4b30ed1927128b93438a0070d GIT binary patch literal 423 zcmV;Y0a*SYiwFP!000000|Af$Z~i#p<{!yF!$0!OUcE+a=09&nQ!N`4;6HzfAE}w7 zpwsUa>{vy>p!1*Ly!+`7{$v>fSY)hmR?LVv#K^=l; z*+22+YA`GO;6Kz-_d&qH&_D1_9L;;Q#y@RV_eQ5@pw8vPS|dk z;6GoSP>W7T>pvv?emNT#$v@B|F3nk@sXx55>dWKH+CO9SW$lHy<3Fs2Zw=(K%|94& z_)?&hz&{&O|@1w5*NY?mxH{V#XFe#Xk#O>;&m>=s!j#Eg6YS>OY(r*62&L z-9J!V!B^1C$3H+=fSN2p=|9fr|CHlA?ms3Td(J~p*gu@iuvj2@)jw(vfG^p==s#h~ RmH<52?LUHfi%O6I005+i(jfo< literal 0 HcmV?d00001 diff --git a/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 b/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 new file mode 100644 index 0000000000000000000000000000000000000000..872eef28e9039ee55aff5824225abaf59c64a9e2 GIT binary patch literal 24 acmb2|=3oE==H!%w1O|r69Y1EW00jUwB?SWj literal 0 HcmV?d00001