Commit

mv DeepSpeedEngine param_names dict init post _configure_distributed_model

In some backends, when params are moved from host to device, they
might change their Python object id(), which is used as the key in the
param_names dictionary. In such a case, this dict might become invalid.
nelyahu committed Dec 12, 2023
1 parent b186816 commit 3eb90d7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepspeed/runtime/engine.py
@@ -232,9 +232,6 @@ def __init__(
         # for debug purposes - can then debug print: debug_get_module_name(module)
         debug_extract_module_and_param_names(model)

-        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
-        self.param_names = {param: name for name, param in model.named_parameters()}
-
         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
@@ -261,6 +258,9 @@ def __init__(
         # Configure distributed model
         self._configure_distributed_model(model)

+        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
+        self.param_names = {param: name for name, param in model.named_parameters()}
+
         self._get_model_parameters()

         see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
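
The failure mode described in the commit message can be sketched without DeepSpeed or PyTorch. The snippet below is only an illustration under stated assumptions: `FakeParam`, `FakeModel`, and `move_to_device` are hypothetical stand-ins, and the "backend replaces each parameter object on device transfer" behavior is simulated rather than taken from any real backend. It shows that a dict keyed by parameter objects goes stale when it is built before the move, and stays valid when it is built after.

```python
class FakeParam:
    """Hypothetical stand-in for a parameter.

    Like any plain Python object, it hashes and compares by identity,
    so using it as a dict key effectively keys the dict on the object itself.
    """

    def __init__(self, data, device="host"):
        self.data = data
        self.device = device


class FakeModel:
    """Hypothetical model exposing a named_parameters() method."""

    def __init__(self):
        self._params = {"weight": FakeParam([1.0, 2.0]), "bias": FakeParam([0.0])}

    def named_parameters(self):
        return list(self._params.items())

    def move_to_device(self, device):
        # Assumption for illustration: the backend creates a NEW object
        # for every parameter instead of updating the old one in place.
        for name, param in list(self._params.items()):
            self._params[name] = FakeParam(param.data, device)


model = FakeModel()

# Mapping built BEFORE the move (the ordering the commit moves away from).
early_names = {param: name for name, param in model.named_parameters()}

model.move_to_device("device")

# Mapping built AFTER the move (the ordering the commit adopts).
late_names = {param: name for name, param in model.named_parameters()}

current_params = [param for _, param in model.named_parameters()]
stale_hits = [early_names.get(param) for param in current_params]
fresh_hits = [late_names.get(param) for param in current_params]

print(stale_hits)  # [None, None]
print(fresh_hits)  # ['weight', 'bias']
```

In this sketch the early mapping still holds the old objects as keys, so lookups with the live parameters return `None`, which mirrors why the commit builds `param_names` only after `_configure_distributed_model(model)` has run.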