Merge branch 'master' into fix_z3
loadams authored Nov 19, 2024
2 parents 8b30987 + 2e0c39b commit 7834887
Showing 19 changed files with 50 additions and 40 deletions.
6 changes: 6 additions & 0 deletions csrc/xpu/adam/multi_tensor_apply.dp.hpp
@@ -109,6 +109,12 @@ class multi_tensor_apply_kernel {
std::tuple<ArgTypes...> args;
};

+ // to make sure multi_tensor_apply_kernel can be used in sycl::buffer
+ namespace sycl {
+ template <typename T, typename U, typename... ArgTypes>
+ struct is_device_copyable<multi_tensor_apply_kernel<T, U, ArgTypes...>> : std::true_type {};
+ } // namespace sycl

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
17 changes: 10 additions & 7 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -116,7 +116,7 @@ def show_transformer_file_map(self):
self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')

def _build_global_state(self):
- sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+ sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

@@ -137,14 +137,17 @@ def get_final_norm_layer_id(self):

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
- sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+ sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

return self.global_state[ITERATION_KEY]

def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
- sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
+ sd_list = [
+     torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
+     for fname in self.tp_to_embedding_map[tp_index]
+ ]
sd = self._merge_state_dicts(sd_list)
return sd

@@ -154,7 +157,7 @@ def get_embedding_files(self, tp_index: int) -> list:

def _get_checkpoint_value(self, key):
if not key in self.global_state:
- sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+ sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[key] = sd.get(key, None)

return self.global_state[key]
@@ -169,7 +172,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
- sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+ sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]

merged_sd = None
for sd in sd_list:
@@ -185,7 +188,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
- sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+ sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
@@ -196,7 +199,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list:

def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
- sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
+ sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
return sd

def get_final_norm_files(self, tp_index: int) -> list:
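Every torch.load call touched in this commit picks up the same one-argument change: an explicit weights_only=False, as in the file above and the ones below. A minimal sketch of what that flag controls, assuming a PyTorch build that supports it (1.13 or later); the file name and checkpoint contents here are illustrative, not taken from the repository:

import argparse
import torch

# A toy checkpoint that mixes tensors with an arbitrary Python object, similar in
# spirit to the mp_rank_* shards loaded above (contents are hypothetical).
ckpt = {"weight": torch.randn(4, 4), "args": argparse.Namespace(hidden_dim=4)}
torch.save(ckpt, "demo_ckpt.pt")

# weights_only=True routes through a restricted unpickler that only admits tensors,
# containers, and a small allowlist of types, so objects like argparse.Namespace
# are typically rejected.
try:
    torch.load("demo_ckpt.pt", map_location="cpu", weights_only=True)
except Exception as err:
    print("weights_only=True refused the checkpoint:", type(err).__name__)

# weights_only=False keeps the full-pickle behavior these checkpoints rely on,
# which is what the hunks in this commit now state explicitly.
sd = torch.load("demo_ckpt.pt", map_location=torch.device("cpu"), weights_only=False)
print(sd["weight"].shape, sd["args"])

A checkpoint holding only tensors and plain containers loads the same way under either setting; spelling out False simply preserves the existing behavior for shards that also carry richer objects.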
14 changes: 7 additions & 7 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):


def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
- state_dict = torch.load(optim_files[dp_index], map_location='cpu')
+ state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)

flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
@@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
raise ValueError(f"Cannot parse dp_rank from {p}")

paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
- shards = [torch.load(p) for p in paths]
+ shards = [torch.load(p, weights_only=False) for p in paths]

if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
@@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):


def _parse_model_states_stage3(files):
- return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
+ return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]


def _save_optimizer_state(args, ds_checkpoint):
@@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):


def _save_optimizer_state_stage3(args, optim_files):
- sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
+ sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
output_sd = sd[OPTIMIZER_STATE_DICT]
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
@@ -446,15 +446,15 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):


def _get_zero_stage(optim_files):
- state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
+ state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
zero_stage = optimizer_state.get(ZERO_STAGE, 1)
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
- sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
+ sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
@@ -488,7 +488,7 @@ def main(args):

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
- mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
+ mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/universal_checkpoint.py
@@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
- ckpt_dict = torch.load(ckpt_file)
+ ckpt_dict = torch.load(ckpt_file, weights_only=False)

if key == "step":
step = ckpt_dict
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
@@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
- sd = torch.load(state_file, map_location=torch.device('cpu'))
+ sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
for key in keys_to_ignore:
sd.pop(key, None)

4 changes: 2 additions & 2 deletions deepspeed/inference/engine.py
@@ -452,15 +452,15 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
checkpoint = sd_loader['checkpoints']

if type(checkpoint) is list:
- self.sd = torch.load(checkpoint[0], map_location='cpu')
+ self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
self.key_list = list(self.sd.keys())

self.load_model_with_checkpoint(self.module)

for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
- self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
+ self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -80,7 +80,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
else:
model_param_json_fname = "pytorch_model.bin.index.json"
model_file_fname = "pytorch_model.bin"
- self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
+ self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)

model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)

@@ -205,7 +205,7 @@ def populate_model_parameters(self) -> None:
buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)

- buffer = torch.load(buffer_path)
+ buffer = torch.load(buffer_path, weights_only=False)
metadata = json.load(open(metadata_path, "r"))
metadata = ModelMetadata.parse_raw(metadata)

8 changes: 4 additions & 4 deletions deepspeed/module_inject/replace_module.py
@@ -415,7 +415,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

for i in range(len(checkpoint)):
- sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
+ sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
@@ -437,7 +437,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count)
]
- sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
+ sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
@@ -457,7 +457,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
- sds = [torch.load(ckpt_file, map_location='cpu')]
+ sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
@@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
from safetensors.torch import load_file
sd = load_file(checkpoint)
else:
- sd = torch.load(checkpoint, map_location='cpu')
+ sd = torch.load(checkpoint, map_location='cpu', weights_only=False)

policy = {}
if orig_class is not None:
2 changes: 1 addition & 1 deletion deepspeed/runtime/base_optimizer.py
@@ -22,7 +22,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
- optim_sd = torch.load(optim_state_path)
+ optim_sd = torch.load(optim_state_path, weights_only=False)

self._load_global_state(optim_sd)

@@ -58,7 +58,7 @@ def load(self, path: str, map_location=None):
if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag
logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
- partition = torch.load(path, map_location=map_location)
+ partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition

@@ -25,7 +25,7 @@ def save(self, state_dict, path: str):

def load(self, path: str, map_location=None):
logger.info(f"[Torch] Loading checkpoint from {path}...")
- partition = torch.load(path, map_location=map_location)
+ partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Torch] Loaded checkpoint from {path}.")
return partition

4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -2741,7 +2741,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'

- optim_sd = torch.load(optim_state_path)
+ optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state_stage3(optim_sd)

key_list = ["fp32", "exp_avg", "exp_avg_sq"]
@@ -2799,7 +2799,7 @@ def load_hp_checkpoint_state(self, folder, key):
local_rank = dist.get_local_rank()

# Load tensors from files and reshape them to flat vectors
- loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
+ loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)

# Partition the loaded data according to the local rank
world_size = dist.get_world_size(group=self.dp_process_group)
4 changes: 2 additions & 2 deletions deepspeed/utils/zero_to_fp32.py
@@ -102,7 +102,7 @@ def get_model_state_files(checkpoint_dir):
def parse_model_states(files):
zero_model_states = []
for file in files:
- state_dict = torch.load(file, map_location=device)
+ state_dict = torch.load(file, map_location=device, weights_only=False)

if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
@@ -149,7 +149,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in tqdm(files, desc='Loading checkpoint shards'):
- state_dict = torch.load(f, map_location=device, mmap=True)
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
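The parse_optim_states hunk above pairs mmap=True with popping the nested optimizer section immediately after load. A small illustration of that memory-bounding pattern, assuming PyTorch 2.1 or newer for mmap support; the file and key names below are hypothetical:

import torch

# Hypothetical shard: a large nested optimizer section plus the fp32 groups we keep.
shard = {
    "optimizer_state_dict": {
        "optimizer_state_dict": {"exp_avg": torch.zeros(8)},
        "fp32_flat_groups": [torch.randn(8)],
    }
}
torch.save(shard, "shard_00.pt")

# mmap=True leaves tensor storage on disk until it is actually touched, and dropping
# the inner optimizer state right away keeps peak host memory close to what the
# fp32 weights alone require.
sd = torch.load("shard_00.pt", map_location="cpu", mmap=True, weights_only=False)
sd["optimizer_state_dict"].pop("optimizer_state_dict", None)
print(list(sd["optimizer_state_dict"].keys()))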
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/common.py
@@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
for root, _, files in os.walk(save_folder):
for f in files:
if "_expert_" in f and "_model_states" in f:
- expert = torch.load(os.path.join(root, f))
+ expert = torch.load(os.path.join(root, f), weights_only=False)
needed, storages = 0, {}
for name, tensor in expert.items():
needed += tensor.size().numel()
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/test_universal_checkpoint.py
@@ -181,7 +181,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
)

hidden_dim = 10
- loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
+ loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)

ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
11 changes: 6 additions & 5 deletions tests/unit/checkpoint/test_zero_optimizer.py
@@ -264,7 +264,7 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l
model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)

if load_optim:
- saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file))
+ saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
curr_sd = model.optimizer.optimizer.state_dict()
compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)

@@ -523,7 +523,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
all_ckpt_folder = os.path.join(tmpdir, 'all_params')
ds_engine.save_checkpoint(all_ckpt_folder)
all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
- loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
+ loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
all_param_names = set([n for n, p in model.named_parameters()])
assert set(loaded_all_param_model.keys()) == all_param_names

@@ -536,7 +536,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
# Excluding frozen parameters should reduce checkpoint size
assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)

- loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
+ loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
@@ -575,7 +575,7 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage):

custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
- loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module']
+ loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())

custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
@@ -618,7 +618,8 @@ def test_save_tensor_clone(self, tmpdir, zero_stage, use_cpu_device):
clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
torch.save(clone_state_dict, clone_ckpt_file)

- compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file))
+ compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False),
+                     torch.load(clone_ckpt_file, weights_only=False))


class TestZeRONonDistributed(DistributedTest):
@@ -170,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir):
test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
if dist.get_rank() == 0:
load_path = os.path.join(class_tmpdir, "output.pt")
- baseline = torch.load(load_path)
+ baseline = torch.load(load_path, weights_only=False)
test = test.cpu()
assert torch.allclose(
baseline, test,