Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add weights_only=True in torch.load #6094

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def show_transformer_file_map(self):
self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')

def _build_global_state(self):
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @terry-for-github - could you run the pre-commit formatter, that should fix the formatting errors on these files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'll try that.

sd = torch.load(self.mp_rank_files[0], weights_only=True, map_location=torch.device('cpu'))
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

Expand All @@ -137,14 +137,14 @@ def get_final_norm_layer_id(self):

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], weights_only=True, map_location=torch.device('cpu'))
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

return self.global_state[ITERATION_KEY]

def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
sd_list = [torch.load(fname, weights_only=True, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
sd = self._merge_state_dicts(sd_list)
return sd

Expand All @@ -154,7 +154,7 @@ def get_embedding_files(self, tp_index: int) -> list:

def _get_checkpoint_value(self, key):
if not key in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(self.mp_rank_files[0], weights_only=True, map_location=torch.device('cpu'))
self.global_state[key] = sd.get(key, None)

return self.global_state[key]
Expand All @@ -169,7 +169,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd_list = [torch.load(fname, weights_only=True, map_location=torch.device('cpu')) for fname in fname_list]

merged_sd = None
for sd in sd_list:
Expand All @@ -185,7 +185,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd_list = [torch.load(fname, weights_only=True, map_location=torch.device('cpu')) for fname in fname_list]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
Expand All @@ -196,7 +196,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list:

def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], weights_only=True, map_location=torch.device('cpu'))
return sd

def get_final_norm_files(self, tp_index: int) -> list:
Expand Down
14 changes: 7 additions & 7 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):


def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu')
state_dict = torch.load(optim_files[dp_index], weights_only=True, map_location='cpu')

flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
Expand Down Expand Up @@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
raise ValueError(f"Cannot parse dp_rank from {p}")

paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
shards = [torch.load(p) for p in paths]
shards = [torch.load(p, weights_only=True) for p in paths]

if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
Expand Down Expand Up @@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):


def _parse_model_states_stage3(files):
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
return torch.load(files[0], weights_only=True, map_location=torch.device('cpu'))[PARAM_SHAPES]


def _save_optimizer_state(args, ds_checkpoint):
Expand All @@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):


def _save_optimizer_state_stage3(args, optim_files):
sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
sd = torch.load(optim_files[0], weights_only=True, map_location=torch.device('cpu'))
output_sd = sd[OPTIMIZER_STATE_DICT]
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
Expand All @@ -446,15 +446,15 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):


def _get_zero_stage(optim_files):
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
state_dict = torch.load(optim_files[0], weights_only=True, map_location=torch.device('cpu'))
optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
zero_stage = optimizer_state.get(ZERO_STAGE, 1)
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
sd = torch.load(ds_checkpoint.mp_rank_files[0], weights_only=True, map_location=torch.device('cpu'))
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
Expand Down Expand Up @@ -488,7 +488,7 @@ def main(args):

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
mp_sd = torch.load(mp_rank_file, weights_only=True, map_location=torch.device('cpu'))
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
ckpt_dict = torch.load(ckpt_file, weights_only=True)

if key == "step":
step = ckpt_dict
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
sd = torch.load(state_file, map_location=torch.device('cpu'))
sd = torch.load(state_file, weights_only=True, map_location=torch.device('cpu'))
for key in keys_to_ignore:
sd.pop(key, None)

Expand Down
4 changes: 2 additions & 2 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,15 +460,15 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
checkpoint = sd_loader['checkpoints']

if type(checkpoint) is list:
self.sd = torch.load(checkpoint[0], map_location='cpu')
self.sd = torch.load(checkpoint[0], weights_only=True, map_location='cpu')
self.key_list = list(self.sd.keys())

self.load_model_with_checkpoint(self.module)

for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
self.sd = torch.load(checkpoint[i], weights_only=True, map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def populate_model_parameters(self) -> None:
buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)

buffer = torch.load(buffer_path)
buffer = torch.load(buffer_path, weights_only=True)
metadata = json.load(open(metadata_path, "r"))
metadata = ModelMetadata.parse_raw(metadata)

Expand Down
8 changes: 4 additions & 4 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

for i in range(len(checkpoint)):
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), weights_only=True, map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
Expand All @@ -427,7 +427,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count)
]
sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
sds = [torch.load(ckpt_file, weights_only=True, map_location='cpu') for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
Expand All @@ -447,7 +447,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
sds = [torch.load(ckpt_file, weights_only=True, map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
Expand Down Expand Up @@ -613,7 +613,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
from safetensors.torch import load_file
sd = load_file(checkpoint)
else:
sd = torch.load(checkpoint, map_location='cpu')
sd = torch.load(checkpoint, weights_only=True, map_location='cpu')

policy = {}
if orig_class is not None:
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path)
optim_sd = torch.load(optim_state_path, weights_only=True)

self._load_global_state(optim_sd)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def load(self, path: str, map_location=None):
if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag
logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
partition = torch.load(path, map_location=map_location)
partition = torch.load(path, weights_only=True, map_location=map_location)
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def save(self, state_dict, path: str):

def load(self, path: str, map_location=None):
logger.info(f"[Torch] Loading checkpoint from {path}...")
partition = torch.load(path, map_location=map_location)
partition = torch.load(path, weights_only=True, map_location=map_location)
logger.info(f"[Torch] Loaded checkpoint from {path}.")
return partition

Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/fused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
checkpoint = torch.load("saved.pth", weights_only=True)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/unfused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
checkpoint = torch.load("saved.pth", weights_only=True)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,7 +2165,7 @@ def forward(self, input):
with deepspeed.zero.Init():
model = MyModel()

state_dict = torch.load(model_path, map_location="cpu")
state_dict = torch.load(model_path, weights_only=True, map_location="cpu")

def load(module: nn.Module, prefix=""):
# because zero3 puts placeholders in model params, this context
Expand Down
6 changes: 3 additions & 3 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2649,7 +2649,7 @@ def load_state_dict(self,
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
checkpoint = torch.load("saved.pth", weights_only=True)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
Expand Down Expand Up @@ -2691,7 +2691,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'

optim_sd = torch.load(optim_state_path)
optim_sd = torch.load(optim_state_path, weights_only=True)
self._load_global_state_stage3(optim_sd)

key_list = ["fp32", "exp_avg", "exp_avg_sq"]
Expand Down Expand Up @@ -2749,7 +2749,7 @@ def load_hp_checkpoint_state(self, folder, key):
local_rank = dist.get_local_rank()

# Load tensors from files and reshape them to flat vectors
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=True).view(-1)

# Partition the loaded data according to the local rank
world_size = dist.get_world_size(group=self.dp_process_group)
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2347,7 +2347,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
checkpoint = torch.load("saved.pth", weights_only=True)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
Expand Down
6 changes: 3 additions & 3 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_model_state_files(checkpoint_dir):
def parse_model_states(files):
zero_model_states = []
for file in files:
state_dict = torch.load(file, map_location=device)
state_dict = torch.load(file, weights_only=True, map_location=device)

if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
Expand Down Expand Up @@ -143,7 +143,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
state_dict = torch.load(f, weights_only=True, map_location=device)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
Expand Down Expand Up @@ -524,7 +524,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
loaded with ``torch.load(file, weights_only=True)`` + ``load_state_dict()`` and used for training without DeepSpeed.

Args:
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
for root, _, files in os.walk(save_folder):
for f in files:
if "_expert_" in f and "_model_states" in f:
expert = torch.load(os.path.join(root, f))
expert = torch.load(os.path.join(root, f), weights_only=True)
needed, storages = 0, {}
for name, tensor in expert.items():
needed += tensor.size().numel()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/test_universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
)

hidden_dim = 10
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=True)

ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
Expand Down
Loading
Loading