Skip to content

Commit

Permalink
Fix incorrect patch in zero.init
Browse files Browse the repository at this point in the history
  • Loading branch information
VeryLazyBoy authored Aug 12, 2024
1 parent ffe0af2 commit 9633363
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,16 +464,22 @@ def wrapper(*args, **kwargs):

return wrapper

def _enable_class_apply(cls):
def _enable_class_apply_backup(cls):
cls._old_apply_of_skip_init_hook = cls._apply

def _enable_class_apply_replace(cls):
cls._apply = partition_after_empty_init(cls._apply)

def _disable_class_apply(cls):
cls._apply = cls._old_apply_of_skip_init_hook

# add hooks for to_empty: apply_(empty_like)
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_enable_class_apply(subclass)
all_subclasses = get_all_subclasses(torch.nn.modules.module.Module)
# split into two steps to address the inheritance problem
for subclass in all_subclasses:
_enable_class_apply_backup(subclass)
for subclass in all_subclasses:
_enable_class_apply_replace(subclass)

# add a restore hook when exiting skip_init
module.to_empty = post_wrapper_to_empty(module.to_empty)
Expand Down Expand Up @@ -521,17 +527,23 @@ def wrapper(module, *args, **kwargs):

return wrapper

def _enable_class(cls):
def _enable_class_backup(cls):
cls._old_init = cls.__init__

def _enable_class_repalce(cls):
cls.__init__ = partition_after(cls.__init__)

def _init_subclass(cls, **kwargs):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)

# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_enable_class(subclass)
all_subclasses = get_all_subclasses(torch.nn.modules.module.Module)
# split into two steps to address the inheritance problem
for subclass in all_subclasses:
_enable_class_backup(subclass)
for subclass in all_subclasses:
_enable_class_replace(subclass)

# holding onto some methods so we can put them back the way they were in __exit__
torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
Expand Down

0 comments on commit 9633363

Please sign in to comment.