diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 315604c87..39118f050 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2165,7 +2165,12 @@ def _modify_orbax_param_info(info, value): ) return info - item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args) + item_ = jax.tree.map( + lambda x, y: None if x is None else _make_orbax_internal_metadata(x, y), + item_, + restore_args, + is_leaf=lambda x: x is None, + ) param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_) param_infos_ = jax.tree.map( _modify_orbax_param_info, param_infos_, state_dict_to_restore