Commit d72edb3

fix lm head overridden issue, move it from checkpoint in-loop loading to out loop (#4206)

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
sywangyi and tjruwase authored Oct 5, 2023
1 parent 4294ea1 commit d72edb3
Showing 2 changed files with 11 additions and 15 deletions.
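
For reference, a minimal runnable sketch (not part of the diff) of the weight tying that the new set_lm_head helper performs: when lm_head.weight is still a meta tensor, i.e. declared but never materialized, it is pointed at the already-loaded input-embedding weight instead. The ToyModel wrapper and the sizes below are hypothetical; the helper body mirrors the one this commit adds to replace_module.py.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a replaced HF-style causal LM."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)          # materialized embedding
        with torch.device("meta"):
            self.lm_head = nn.Linear(4, 8, bias=False)  # weight left on meta


def set_lm_head(module):
    # Same logic as the helper added in replace_module.py.
    embedding_weight = None
    for n, p in module.named_parameters():
        if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
            embedding_weight = p
    if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
            module.lm_head, "weight") and module.lm_head.weight.is_meta:
        module.lm_head.weight = embedding_weight


model = ToyModel()
print(model.lm_head.weight.is_meta)                       # True: never materialized
set_lm_head(model)
print(model.lm_head.weight is model.embed_tokens.weight)  # True: now tied

Because the embedding parameter is reused directly, lm_head ends up sharing storage with the input embedding, the usual arrangement for models with tied input/output embeddings.
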
7 changes: 0 additions & 7 deletions deepspeed/module_inject/load_checkpoint.py
@@ -276,13 +276,6 @@ def load_module_recursive(module, prefix='', level=0):
                     level + 1)
 
     load_module_recursive(r_module)
-    embedding_weight = None
-
-    for n, p in r_module.named_parameters():
-        if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
-            embedding_weight = p
-    if embedding_weight is not None and r_module.lm_head.weight.is_meta:
-        r_module.lm_head.weight = embedding_weight
 
     for sd_ in sd:
         del sd_
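
The deleted block above ran inside the per-shard loading routine in load_checkpoint.py, i.e. once for every checkpoint shard; the commit instead calls the new set_lm_head helper a single time after the full loading loop (see the replace_module.py hunks below), so per-shard loading can no longer clash with an already-tied lm_head. Below is a hedged, self-contained toy illustration of that ordering concern; the shard contents and helper names are made up and this is not a reconstruction of the exact original failure.

import torch
import torch.nn as nn


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        with torch.device("meta"):
            self.embed_tokens = nn.Embedding(8, 4)
            self.lm_head = nn.Linear(4, 8, bias=False)


def load_shard(model, shard):
    # Illustrative per-shard load: (re)materialize the parameters this shard holds.
    for name, tensor in shard.items():
        owner, param = name.rsplit(".", 1)
        setattr(model.get_submodule(owner), param, nn.Parameter(tensor))


def tie_lm_head(model):
    if model.lm_head.weight.is_meta:
        model.lm_head.weight = model.embed_tokens.weight


shards = [{"embed_tokens.weight": torch.randn(8, 4)},
          {"embed_tokens.weight": torch.randn(8, 4)}]   # a later shard reloads it

model = ToyModel()
for shard in shards:
    load_shard(model, shard)
    # Old behaviour: tying here (in-loop) binds lm_head to the parameter from
    # this shard, which the next iteration then replaces on embed_tokens.
tie_lm_head(model)   # New behaviour: tie once, after every shard is applied.
print(model.lm_head.weight is model.embed_tokens.weight)   # True
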
19 changes: 11 additions & 8 deletions deepspeed/module_inject/replace_module.py
@@ -296,6 +296,15 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
 
         return new_module
 
+    def set_lm_head(module):
+        embedding_weight = None
+        for n, p in module.named_parameters():
+            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
+                embedding_weight = p
+        if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
+                module.lm_head, "weight") and module.lm_head.weight.is_meta:
+            module.lm_head.weight = embedding_weight
+
     if checkpoint_dict is not None and not config.replace_with_kernel_inject:
         # AutoTP shard loading
         checkpoint = checkpoint_dict["checkpoints"]
@@ -309,6 +318,7 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
                                              checkpoint=checkpoint_file)
             pbar.update(1)
             gc.collect()
+        set_lm_head(replaced_module)
     else:
         replaced_module = replace_module(model=model,
                                          orig_class=orig_layer_impl,
@@ -386,6 +396,7 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
                                       container=container_g)
             sds = [None for _ in sds]
             gc.collect()
+        set_lm_head(replaced_module)
         print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
 
     if config.save_mp_checkpoint_path is not None:
@@ -554,14 +565,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
         "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"
 
     replaced_module, _ = _replace_module(model, policy, state_dict=sd)
-    if checkpoint is not None:
-        embedding_weight = None
-        for n, p in replaced_module.named_parameters():
-            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
-                embedding_weight = p
-        if embedding_weight is not None and hasattr(replaced_module, "lm_head") and hasattr(
-                replaced_module.lm_head, "weight") and replaced_module.lm_head.weight.is_meta:
-            replaced_module.lm_head.weight = embedding_weight
     return replaced_module


