From 1fdad1fa52f525d64132cb11e2746b06751efb22 Mon Sep 17 00:00:00 2001
From: baodi
Date: Tue, 19 Nov 2024 09:38:27 -0800
Subject: [PATCH 1/2] make xpu ops compatible with oneapi 2025.0 (#6760)

Compatibility update for xpu ops.

This PR makes the xpu ops compatible with the oneAPI 2025.0 toolkit,
which we need in order to develop and ship our most demanding models on
this hardware.

---------

Signed-off-by: baodii
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams
---
 csrc/xpu/adam/multi_tensor_apply.dp.hpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/csrc/xpu/adam/multi_tensor_apply.dp.hpp b/csrc/xpu/adam/multi_tensor_apply.dp.hpp
index 01964e1a4d38..14a130e2a23e 100644
--- a/csrc/xpu/adam/multi_tensor_apply.dp.hpp
+++ b/csrc/xpu/adam/multi_tensor_apply.dp.hpp
@@ -109,6 +109,12 @@ class multi_tensor_apply_kernel {
     std::tuple<ArgTypes...> args;
 };
 
+// to make sure multi_tensor_apply_kernel can be used in sycl::buffer
+namespace sycl {
+template <typename T, typename U, typename... ArgTypes>
+struct is_device_copyable<multi_tensor_apply_kernel<T, U, ArgTypes...>> : std::true_type {};
+}  // namespace sycl
+
 template <int depth, typename T, typename... ArgTypes>
 void multi_tensor_apply(int block_size,
                         int chunk_size,

From 2e0c39b55ce55dfeeff1224203d9035555e228d9 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Tue, 19 Nov 2024 11:09:52 -0800
Subject: [PATCH 2/2] Add explicit parameters for torch.load (#6751)

Successor PR to #6094:

> FutureWarning: You are using torch.load with weights_only=False (the
> current default value), which uses the default pickle module
> implicitly. It is possible to construct malicious pickle data which
> will execute arbitrary code during unpickling (See
> https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models
> for more details). In a future release, the default value for
> weights_only will be flipped to True. This limits the functions that
> could be executed during unpickling. Arbitrary objects will no longer
> be allowed to be loaded via this mode unless they are explicitly
> allowlisted by the user via torch.serialization.add_safe_globals. We
> recommend you start setting weights_only=True for any use case where
> you don't have full control of the loaded file. Please open an issue
> on GitHub for any issues related to this experimental feature.

Todo:
- [ ] Update values in non-test files to True where necessary.
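
For reference, here is a minimal sketch of the behavior the warning
describes, assuming torch >= 2.4; `TrainingArgs` is a hypothetical
stand-in for the argparse namespaces and other non-tensor objects that
DeepSpeed checkpoints carry, which is why the call sites below opt into
`weights_only=False` for now:

```python
import torch


class TrainingArgs:
    """Hypothetical non-tensor object stored inside a checkpoint."""

    def __init__(self, lr: float = 1e-4):
        self.lr = lr


torch.save({"step": 100, "args": TrainingArgs(), "w": torch.zeros(2)}, "ckpt.pt")

# Explicit opt-out: preserves today's behavior even after the default flips.
sd = torch.load("ckpt.pt", weights_only=False)

# Safe mode rejects arbitrary classes during unpickling...
try:
    torch.load("ckpt.pt", weights_only=True)
except Exception as err:
    print(f"weights_only=True refused the load: {err}")

# ...unless they are explicitly allowlisted first.
torch.serialization.add_safe_globals([TrainingArgs])
sd = torch.load("ckpt.pt", weights_only=True)
```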
---
 deepspeed/checkpoint/deepspeed_checkpoint.py  | 17 ++++++++++-------
 deepspeed/checkpoint/ds_to_universal.py       | 14 +++++++-------
 deepspeed/checkpoint/universal_checkpoint.py  |  2 +-
 deepspeed/checkpoint/zero_checkpoint.py       |  2 +-
 deepspeed/inference/engine.py                 |  4 ++--
 .../v2/checkpoint/huggingface_engine.py       |  2 +-
 .../inference_policy_base.py                  |  2 +-
 deepspeed/module_inject/replace_module.py     |  8 ++++----
 deepspeed/runtime/base_optimizer.py           |  2 +-
 .../nebula_checkpoint_engine.py               |  2 +-
 .../torch_checkpoint_engine.py                |  2 +-
 deepspeed/runtime/zero/stage3.py              |  4 ++--
 deepspeed/utils/zero_to_fp32.py               |  4 ++--
 tests/unit/checkpoint/common.py               |  2 +-
 .../checkpoint/test_universal_checkpoint.py   |  2 +-
 tests/unit/checkpoint/test_zero_optimizer.py  | 11 ++++++-----
 .../test_configurable_parallel_mp.py          |  2 +-
 .../test_configurable_parallel_pp.py          |  2 +-
 18 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py
index 31997177a262..9a368b7a0a25 100644
--- a/deepspeed/checkpoint/deepspeed_checkpoint.py
+++ b/deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -116,7 +116,7 @@ def show_transformer_file_map(self):
         self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')
 
     def _build_global_state(self):
-        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
         self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
         self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
 
@@ -137,14 +137,17 @@ def get_final_norm_layer_id(self):
 
     def get_iteration(self):
         if not ITERATION_KEY in self.global_state:
-            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
             self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
 
         return self.global_state[ITERATION_KEY]
 
     def get_embedding_state(self, tp_index: int) -> Dict:
         assert tp_index in self.tp_to_embedding_map.keys()
-        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
+        sd_list = [
+            torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
+            for fname in self.tp_to_embedding_map[tp_index]
+        ]
         sd = self._merge_state_dicts(sd_list)
         return sd
 
@@ -154,7 +157,7 @@ def get_embedding_files(self, tp_index: int) -> list:
 
     def _get_checkpoint_value(self, key):
         if not key in self.global_state:
-            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
             self.global_state[key] = sd.get(key, None)
 
         return self.global_state[key]
@@ -169,7 +172,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
         assert tp_index < self.tp_degree
         assert pp_index < self.pp_degree
         fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
-        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+        sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
 
         merged_sd = None
         for sd in sd_list:
@@ -185,7 +188,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
         assert pp_index < self.pp_degree
         t_list = []
         for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
-            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+            sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
             sd = self._merge_state_dicts(sd_list)
             t_list.append(sd)
         return t_list
@@ -196,7 +199,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list:
 
     def get_final_norm_state(self, tp_index: int) -> Dict:
         assert tp_index in self.tp_to_final_norm_map.keys()
-        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
+        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
         return sd
 
     def get_final_norm_files(self, tp_index: int) -> list:
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index e5974a30df22..f7b75eee66d0 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
 
 
 def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
-    state_dict = torch.load(optim_files[dp_index], map_location='cpu')
+    state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)
 
     flat_state = dict(
         exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
@@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
                 raise ValueError(f"Cannot parse dp_rank from {p}")
 
         paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
-        shards = [torch.load(p) for p in paths]
+        shards = [torch.load(p, weights_only=False) for p in paths]
 
         if state == "step":
             assert all(v == shards[0] for v in shards), "All shards must have the same step value"
@@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):
 
 
 def _parse_model_states_stage3(files):
-    return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
+    return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]
 
 
 def _save_optimizer_state(args, ds_checkpoint):
@@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):
 
 
 def _save_optimizer_state_stage3(args, optim_files):
-    sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
+    sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
     output_sd = sd[OPTIMIZER_STATE_DICT]
     output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
     zero_output_folder = os.path.join(args.output_folder, "zero")
@@ -446,7 +446,7 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):
 
 
 def _get_zero_stage(optim_files):
-    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
+    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
     optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
     zero_stage = optimizer_state.get(ZERO_STAGE, 1)
     return zero_stage
@@ -454,7 +454,7 @@ def _get_zero_stage(optim_files):
 
 def _inject_missing_state(ds_checkpoint):
     if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
-        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
+        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
         if UNIVERSAL_CHECKPOINT_INFO not in sd:
             ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
             ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
@@ -488,7 +488,7 @@ def main(args):
 
     slice_shapes = []
     for mp_rank_file in ds_checkpoint.mp_rank_files:
-        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
+        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
         slice_shapes += mp_sd[PARAM_SHAPES]
 
     # fix back to normal flat dict, merge duplicates for tp>1
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index 064891a8bb54..266d5a063595 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     step = None
     for key in hp_keys:
         ckpt_file = os.path.join(folder, f"{key}.pt")
-        ckpt_dict = torch.load(ckpt_file)
+        ckpt_dict = torch.load(ckpt_file, weights_only=False)
 
         if key == "step":
             step = ckpt_dict
diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py
index 6730b93dfd4f..c85f0241005d 100644
--- a/deepspeed/checkpoint/zero_checkpoint.py
+++ b/deepspeed/checkpoint/zero_checkpoint.py
@@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st
         state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
         merged_sd = None
         for state_file in state_file_list:
-            sd = torch.load(state_file, map_location=torch.device('cpu'))
+            sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
             for key in keys_to_ignore:
                 sd.pop(key, None)
 
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 6574d49fb132..cfca1ff4fe4c 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -452,7 +452,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
             checkpoint = sd_loader['checkpoints']
 
             if type(checkpoint) is list:
-                self.sd = torch.load(checkpoint[0], map_location='cpu')
+                self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
                 self.key_list = list(self.sd.keys())
 
                 self.load_model_with_checkpoint(self.module)
@@ -460,7 +460,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
                 for i in range(1, len(checkpoint)):
                     if not dist.is_initialized() or dist.get_rank() == 0:
                         print(f"loading checkpoint ({i})")
-                    self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
+                    self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
                     self.key_list = list(self.sd.keys())
                     self.load_model_with_checkpoint(self.module)
             else:
diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
index d88d99ebebfd..b17bb886838f 100644
--- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -80,7 +80,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
             else:
                 model_param_json_fname = "pytorch_model.bin.index.json"
                 model_file_fname = "pytorch_model.bin"
-                self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
+                self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)
 
             model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)
 
diff --git a/deepspeed/inference/v2/model_implementations/inference_policy_base.py b/deepspeed/inference/v2/model_implementations/inference_policy_base.py
index d5a326c03599..2f4266a8cb88 100644
--- a/deepspeed/inference/v2/model_implementations/inference_policy_base.py
+++ b/deepspeed/inference/v2/model_implementations/inference_policy_base.py
@@ -205,7 +205,7 @@ def populate_model_parameters(self) -> None:
         buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
         metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
 
-        buffer = torch.load(buffer_path)
+        buffer = torch.load(buffer_path, weights_only=False)
         metadata = json.load(open(metadata_path, "r"))
         metadata = ModelMetadata.parse_raw(metadata)
 
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 1c5745dcf168..7afe6ca903fb 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -415,7 +415,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                 pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
 
                 for i in range(len(checkpoint)):
-                    sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
+                    sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
                     load_model_with_checkpoint(replaced_module,
                                                sd,
                                                mp_replace,
@@ -437,7 +437,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                         os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                         for j in range(sd_count)
                     ]
-                    sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
+                    sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
                     load_model_with_checkpoint(replaced_module,
                                                sds,
                                                mp_replace,
@@ -457,7 +457,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                         pbar.update(1)
 
                     ckpt_file = os.path.join(base_dir1, checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
-                    sds = [torch.load(ckpt_file, map_location='cpu')]
+                    sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
                     load_model_with_checkpoint(replaced_module,
                                                sds,
                                                mp_replace,
@@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
             from safetensors.torch import load_file
             sd = load_file(checkpoint)
         else:
-            sd = torch.load(checkpoint, map_location='cpu')
+            sd = torch.load(checkpoint, map_location='cpu', weights_only=False)
 
     policy = {}
     if orig_class is not None:
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 6cfd66f1cc38..b8df7499450d 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -22,7 +22,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
         assert os.path.isfile(
             optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path)
+        optim_sd = torch.load(optim_state_path, weights_only=False)
 
         self._load_global_state(optim_sd)
 
diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
index e26e3243c4b5..e834bf0d22d7 100644
--- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
@@ -58,7 +58,7 @@ def load(self, path: str, map_location=None):
         if not self.enable_nebula_load and first_load_flag:
             self.tag_flag = tag
             logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
-            partition = torch.load(path, map_location=map_location)
+            partition = torch.load(path, map_location=map_location, weights_only=False)
             logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
             return partition
 
diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
index 5cd44864bb2e..076c638532ad 100644
--- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
@@ -25,7 +25,7 @@ def save(self, state_dict, path: str):
 
     def load(self, path: str, map_location=None):
         logger.info(f"[Torch] Loading checkpoint from {path}...")
-        partition = torch.load(path, map_location=map_location)
+        partition = torch.load(path, map_location=map_location, weights_only=False)
         logger.info(f"[Torch] Loaded checkpoint from {path}.")
         return partition
 
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 04d52319ae8c..99a5ecf41a2f 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2741,7 +2741,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
         assert os.path.isfile(
             optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
 
-        optim_sd = torch.load(optim_state_path)
+        optim_sd = torch.load(optim_state_path, weights_only=False)
         self._load_global_state_stage3(optim_sd)
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
@@ -2799,7 +2799,7 @@ def load_hp_checkpoint_state(self, folder, key):
             local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
+        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
 
         # Partition the loaded data according to the local rank
         world_size = dist.get_world_size(group=self.dp_process_group)
diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index c0768deae62b..e93cb1c95f15 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -102,7 +102,7 @@ def get_model_state_files(checkpoint_dir):
 def parse_model_states(files):
     zero_model_states = []
     for file in files:
-        state_dict = torch.load(file, map_location=device)
+        state_dict = torch.load(file, map_location=device, weights_only=False)
 
         if BUFFER_NAMES not in state_dict:
             raise ValueError(f"{file} is not a model state checkpoint")
@@ -149,7 +149,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
     for f in tqdm(files, desc='Loading checkpoint shards'):
-        state_dict = torch.load(f, map_location=device, mmap=True)
+        state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
         # and also handle the case where it was already removed by another helper script
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py
index 3fb13b214ea0..001c08f1a99f 100644
--- a/tests/unit/checkpoint/common.py
+++ b/tests/unit/checkpoint/common.py
@@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
         for root, _, files in os.walk(save_folder):
             for f in files:
                 if "_expert_" in f and "_model_states" in f:
-                    expert = torch.load(os.path.join(root, f))
+                    expert = torch.load(os.path.join(root, f), weights_only=False)
                     needed, storages = 0, {}
                     for name, tensor in expert.items():
                         needed += tensor.size().numel()
diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py
index 27ddf0cdef39..46d4294bdd0d 100644
--- a/tests/unit/checkpoint/test_universal_checkpoint.py
+++ b/tests/unit/checkpoint/test_universal_checkpoint.py
@@ -181,7 +181,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
         )
         hidden_dim = 10
 
-        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
+        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)
 
         ds_config["checkpoint"] = {"load_universal": True}
         univ_model = SimpleModel(hidden_dim)
diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py
index 84b4eca6e2ca..44966b331d0f 100644
--- a/tests/unit/checkpoint/test_zero_optimizer.py
+++ b/tests/unit/checkpoint/test_zero_optimizer.py
@@ -264,7 +264,7 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l
         model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)
 
         if load_optim:
-            saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file))
+            saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
             curr_sd = model.optimizer.optimizer.state_dict()
             compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)
 
@@ -523,7 +523,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
         all_ckpt_folder = os.path.join(tmpdir, 'all_params')
         ds_engine.save_checkpoint(all_ckpt_folder)
         all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
-        loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
+        loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
         all_param_names = set([n for n, p in model.named_parameters()])
         assert set(loaded_all_param_model.keys()) == all_param_names
 
@@ -536,7 +536,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
 
         # Excluding frozen parameters should reduce checkpoint size
         assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)
-        loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
+        loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
         frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
         loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
         overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
@@ -575,7 +575,7 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage):
         custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
             os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
-        loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module']
+        loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
         loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())
 
         custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
@@ -618,7 +618,8 @@ def test_save_tensor_clone(self, tmpdir, zero_stage, use_cpu_device):
 
         clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
         torch.save(clone_state_dict, clone_ckpt_file)
 
-        compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file))
+        compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False),
+                            torch.load(clone_ckpt_file, weights_only=False))
 
 
 class TestZeRONonDistributed(DistributedTest):
diff --git a/tests/unit/model_parallelism/test_configurable_parallel_mp.py b/tests/unit/model_parallelism/test_configurable_parallel_mp.py
index cca1ef3584ad..a7b0d3431ee9 100644
--- a/tests/unit/model_parallelism/test_configurable_parallel_mp.py
+++ b/tests/unit/model_parallelism/test_configurable_parallel_mp.py
@@ -170,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir):
         test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
         if dist.get_rank() == 0:
             load_path = os.path.join(class_tmpdir, "output.pt")
-            baseline = torch.load(load_path)
+            baseline = torch.load(load_path, weights_only=False)
             test = test.cpu()
             assert torch.allclose(
                 baseline, test,
diff --git a/tests/unit/model_parallelism/test_configurable_parallel_pp.py b/tests/unit/model_parallelism/test_configurable_parallel_pp.py
index e50fd18577b1..df469044e186 100644
--- a/tests/unit/model_parallelism/test_configurable_parallel_pp.py
+++ b/tests/unit/model_parallelism/test_configurable_parallel_pp.py
@@ -225,7 +225,7 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz
         assert torch.is_tensor(test[0][0])
         test = test[0][0].cpu()
         load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
-        baseline = torch.load(load_path)
+        baseline = torch.load(load_path, weights_only=False)
         assert torch.allclose(
             baseline, test,
             atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"
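
As a sketch of where the remaining Todo could land once non-test call
sites move to `weights_only=True`: a loader in the style of
`TorchCheckpointEngine.load` could opt into the safe mode by allowlisting
the types its checkpoints are known to contain. The allowlist below is
illustrative, not DeepSpeed's actual set:

```python
import argparse

import torch

# Illustrative allowlist: argparse.Namespace stands in for the non-tensor
# objects these checkpoints store; a real migration would enumerate every
# type the checkpoint format actually contains.
torch.serialization.add_safe_globals([argparse.Namespace])


def load_checkpoint(path: str, map_location=None) -> dict:
    # weights_only=True restricts unpickling to tensors, primitive
    # containers, and the globals registered above.
    return torch.load(path, map_location=map_location, weights_only=True)
```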