Skip to content

Commit

Permalink
Add explicit parameters for torch.load (#6751)
Browse files Browse the repository at this point in the history
Successor PR to #6094:

> FutureWarning: You are using torch.load with weights_only=False (the
current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute
arbitrary code during unpickling (See
https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models
for more details). In a future release, the default value for
weights_only will be flipped to True. This limits the functions that
could be executed during unpickling. Arbitrary objects will no longer be
allowed to be loaded via this mode unless they are explicitly
allowlisted by the user via torch.serialization.add_safe_globals. We
recommend you start setting weights_only=True for any use case where you
don't have full control of the loaded file. Please open an issue on
GitHub for any issues related to this experimental feature.

Todo:
- [ ] Update values in non-test files to True where necessary.
  • Loading branch information
loadams authored Nov 19, 2024
1 parent 1fdad1f commit 2e0c39b
Show file tree
Hide file tree
Showing 18 changed files with 44 additions and 40 deletions.
17 changes: 10 additions & 7 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def show_transformer_file_map(self):
self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')

def _build_global_state(self):
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

Expand All @@ -137,14 +137,17 @@ def get_final_norm_layer_id(self):

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

return self.global_state[ITERATION_KEY]

def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
sd_list = [
torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
for fname in self.tp_to_embedding_map[tp_index]
]
sd = self._merge_state_dicts(sd_list)
return sd

Expand All @@ -154,7 +157,7 @@ def get_embedding_files(self, tp_index: int) -> list:

def _get_checkpoint_value(self, key):
if not key in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[key] = sd.get(key, None)

return self.global_state[key]
Expand All @@ -169,7 +172,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]

merged_sd = None
for sd in sd_list:
Expand All @@ -185,7 +188,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
Expand All @@ -196,7 +199,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list:

def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
return sd

def get_final_norm_files(self, tp_index: int) -> list:
Expand Down
14 changes: 7 additions & 7 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):


def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu')
state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)

flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
Expand Down Expand Up @@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
raise ValueError(f"Cannot parse dp_rank from {p}")

paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
shards = [torch.load(p) for p in paths]
shards = [torch.load(p, weights_only=False) for p in paths]

if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
Expand Down Expand Up @@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):


def _parse_model_states_stage3(files):
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]


def _save_optimizer_state(args, ds_checkpoint):
Expand All @@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):


def _save_optimizer_state_stage3(args, optim_files):
sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
output_sd = sd[OPTIMIZER_STATE_DICT]
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
Expand All @@ -446,15 +446,15 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):


def _get_zero_stage(optim_files):
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
zero_stage = optimizer_state.get(ZERO_STAGE, 1)
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
Expand Down Expand Up @@ -488,7 +488,7 @@ def main(args):

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
ckpt_dict = torch.load(ckpt_file, weights_only=False)

if key == "step":
step = ckpt_dict
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
sd = torch.load(state_file, map_location=torch.device('cpu'))
sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
for key in keys_to_ignore:
sd.pop(key, None)

Expand Down
4 changes: 2 additions & 2 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,15 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
checkpoint = sd_loader['checkpoints']

if type(checkpoint) is list:
self.sd = torch.load(checkpoint[0], map_location='cpu')
self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
self.key_list = list(self.sd.keys())

self.load_model_with_checkpoint(self.module)

for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/checkpoint/huggingface_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
else:
model_param_json_fname = "pytorch_model.bin.index.json"
model_file_fname = "pytorch_model.bin"
self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)

model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def populate_model_parameters(self) -> None:
buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)

buffer = torch.load(buffer_path)
buffer = torch.load(buffer_path, weights_only=False)
metadata = json.load(open(metadata_path, "r"))
metadata = ModelMetadata.parse_raw(metadata)

Expand Down
8 changes: 4 additions & 4 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

for i in range(len(checkpoint)):
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
Expand All @@ -437,7 +437,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count)
]
sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
Expand All @@ -457,7 +457,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
Expand Down Expand Up @@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
from safetensors.torch import load_file
sd = load_file(checkpoint)
else:
sd = torch.load(checkpoint, map_location='cpu')
sd = torch.load(checkpoint, map_location='cpu', weights_only=False)

policy = {}
if orig_class is not None:
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path)
optim_sd = torch.load(optim_state_path, weights_only=False)

self._load_global_state(optim_sd)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def load(self, path: str, map_location=None):
if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag
logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
partition = torch.load(path, map_location=map_location)
partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def save(self, state_dict, path: str):

def load(self, path: str, map_location=None):
logger.info(f"[Torch] Loading checkpoint from {path}...")
partition = torch.load(path, map_location=map_location)
partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Torch] Loaded checkpoint from {path}.")
return partition

Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2741,7 +2741,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'

optim_sd = torch.load(optim_state_path)
optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state_stage3(optim_sd)

key_list = ["fp32", "exp_avg", "exp_avg_sq"]
Expand Down Expand Up @@ -2799,7 +2799,7 @@ def load_hp_checkpoint_state(self, folder, key):
local_rank = dist.get_local_rank()

# Load tensors from files and reshape them to flat vectors
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)

# Partition the loaded data according to the local rank
world_size = dist.get_world_size(group=self.dp_process_group)
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_model_state_files(checkpoint_dir):
def parse_model_states(files):
zero_model_states = []
for file in files:
state_dict = torch.load(file, map_location=device)
state_dict = torch.load(file, map_location=device, weights_only=False)

if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
Expand Down Expand Up @@ -149,7 +149,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in tqdm(files, desc='Loading checkpoint shards'):
state_dict = torch.load(f, map_location=device, mmap=True)
state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
for root, _, files in os.walk(save_folder):
for f in files:
if "_expert_" in f and "_model_states" in f:
expert = torch.load(os.path.join(root, f))
expert = torch.load(os.path.join(root, f), weights_only=False)
needed, storages = 0, {}
for name, tensor in expert.items():
needed += tensor.size().numel()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/test_universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
)

hidden_dim = 10
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)

ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/checkpoint/test_zero_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l
model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)

if load_optim:
saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file))
saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
curr_sd = model.optimizer.optimizer.state_dict()
compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)

Expand Down Expand Up @@ -523,7 +523,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
all_ckpt_folder = os.path.join(tmpdir, 'all_params')
ds_engine.save_checkpoint(all_ckpt_folder)
all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
all_param_names = set([n for n, p in model.named_parameters()])
assert set(loaded_all_param_model.keys()) == all_param_names

Expand All @@ -536,7 +536,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
# Excluding frozen parameters should reduce checkpoint size
assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)

loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
Expand Down Expand Up @@ -575,7 +575,7 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage):

custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module']
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())

custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
Expand Down Expand Up @@ -618,7 +618,8 @@ def test_save_tensor_clone(self, tmpdir, zero_stage, use_cpu_device):
clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
torch.save(clone_state_dict, clone_ckpt_file)

compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file))
compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False),
torch.load(clone_ckpt_file, weights_only=False))


class TestZeRONonDistributed(DistributedTest):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir):
test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
if dist.get_rank() == 0:
load_path = os.path.join(class_tmpdir, "output.pt")
baseline = torch.load(load_path)
baseline = torch.load(load_path, weights_only=False)
test = test.cpu()
assert torch.allclose(
baseline, test,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz
assert torch.is_tensor(test[0][0])
test = test[0][0].cpu()
load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
baseline = torch.load(load_path)
baseline = torch.load(load_path, weights_only=False)
assert torch.allclose(
baseline, test,
atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"
Expand Down

0 comments on commit 2e0c39b

Please sign in to comment.