Fix incorrect patch in zero.init #5921

Closed
deepspeed/runtime/zero/partition_parameters.py: 20 changes (14 additions, 6 deletions)
@@ -448,6 +448,7 @@ def wrapper(module, *args, **kwargs):
                 self._post_init_method(_module)
                 return _module
 
+            wrapper._ds_has_wrapped = True
             return wrapper
 
         def post_wrapper_to_empty(f):
@@ -465,8 +466,10 @@ def wrapper(*args, **kwargs):
             return wrapper
 
         def _enable_class_apply(cls):
-            cls._old_apply_of_skip_init_hook = cls._apply
-            cls._apply = partition_after_empty_init(cls._apply)
+            # avoid re-wrap
+            if not hasattr(cls._apply, '_ds_has_wrapped'):
+                cls._old_apply_of_skip_init_hook = cls._apply
+                cls._apply = partition_after_empty_init(cls._apply)
 
         def _disable_class_apply(cls):
             cls._apply = cls._old_apply_of_skip_init_hook
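Note on the guard: the wrapper produced by partition_after_empty_init now tags itself with _ds_has_wrapped, and _enable_class_apply checks for that attribute before patching cls._apply again, so installing the hook twice cannot stack wrappers or overwrite the saved original. A minimal standalone sketch of the same idempotent-patch pattern (the helper names make_wrapper, enable_patch and disable_patch are illustrative, not DeepSpeed APIs):

import functools

def make_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # pre/post hooks would run around the original call here
        return fn(*args, **kwargs)
    wrapper._ds_has_wrapped = True   # marker consulted before any re-wrap
    return wrapper

def enable_patch(cls, name="_apply"):
    method = getattr(cls, name)
    if hasattr(method, "_ds_has_wrapped"):
        return                       # already patched; keep the saved original intact
    setattr(cls, "_old_" + name, method)
    setattr(cls, name, make_wrapper(method))

def disable_patch(cls, name="_apply"):
    setattr(cls, name, getattr(cls, "_old_" + name))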
@@ -519,15 +522,20 @@ def wrapper(module, *args, **kwargs):
                 if init_on_meta:
                     self.skip_init_depth -= 1
 
+            wrapper._ds_has_wrapped = True
             return wrapper
 
         def _enable_class(cls):
-            cls._old_init = cls.__init__
-            cls.__init__ = partition_after(cls.__init__)
+            # avoid re-wrap
+            if not hasattr(cls.__init__, '_ds_has_wrapped'):
+                cls._old_init = cls.__init__
+                cls.__init__ = partition_after(cls.__init__)
 
         def _init_subclass(cls, **kwargs):
-            cls._old_init = cls.__init__
-            cls.__init__ = partition_after(cls.__init__)
+            # avoid re-wrap
+            if not hasattr(cls.__init__, '_ds_has_wrapped'):
+                cls._old_init = cls.__init__
+                cls.__init__ = partition_after(cls.__init__)
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
         for subclass in get_all_subclasses(torch.nn.modules.module.Module):
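The reason the guard matters appears to be restore correctness: without it, calling _enable_class a second time (for example when zero.Init hooks get installed more than once) overwrites cls._old_init with the already-wrapped __init__, so the later unpatch restores the wrapper instead of the true original and the class stays patched after the context exits. A toy reproduction of that failure mode under this reading (the Toy class and the enable/disable helpers are illustrative only, not the DeepSpeed code):

def partition_after(init_fn):
    def wrapper(self, *args, **kwargs):
        init_fn(self, *args, **kwargs)
        # ZeRO would partition the module's parameters here
    wrapper._ds_has_wrapped = True
    return wrapper

class Toy:
    def __init__(self):
        self.built = True

def enable(cls):
    if not hasattr(cls.__init__, '_ds_has_wrapped'):   # the guard from this PR
        cls._old_init = cls.__init__
        cls.__init__ = partition_after(cls.__init__)

def disable(cls):
    cls.__init__ = cls._old_init

enable(Toy)
enable(Toy)   # second patch attempt, e.g. from a hook installed twice
disable(Toy)
# With the guard, __init__ is the pristine original again.
assert not hasattr(Toy.__init__, '_ds_has_wrapped')
# Without the guard, the second enable() would have saved the wrapped __init__
# as _old_init, so disable() would leave the class permanently wrapped.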