Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix patch for parameter partitioning in zero.Init() #6388

Merged
5 commits merged
Sep 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def new_tensor(cls, *args, **kwargs) -> Tensor:


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls):
def get_all_subclasses(cls, include_root=True):
subclass_list = []

def recurse(cl):
Expand All @@ -272,7 +272,10 @@ def recurse(cl):

recurse(cls)

return set(subclass_list)
ret = set(subclass_list)
if include_root:
ret.add(cls)
return ret


@instrument_w_nvtx
Expand Down Expand Up @@ -465,11 +468,13 @@ def wrapper(*args, **kwargs):
return wrapper

def _enable_class_apply(cls):
cls._old_apply_of_skip_init_hook = cls._apply
cls._apply = partition_after_empty_init(cls._apply)
if '_apply' in cls.__dict__:
cls._old_apply_of_skip_init_hook = cls._apply
cls._apply = partition_after_empty_init(cls._apply)

def _disable_class_apply(cls):
cls._apply = cls._old_apply_of_skip_init_hook
if hasattr(cls, '_old_apply_of_skip_init_hook'):
cls._apply = cls._old_apply_of_skip_init_hook

# add hooks for to_empty: apply_(empty_like)
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
Expand Down Expand Up @@ -522,12 +527,14 @@ def wrapper(module, *args, **kwargs):
return wrapper

def _enable_class(cls):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
if '__init__' in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)

def _init_subclass(cls, **kwargs):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
if '__init__' in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)

# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
Expand Down Expand Up @@ -567,7 +574,8 @@ def unpatch_init_and_builtins(self):
if self.patched:

def _disable_class(cls):
cls.__init__ = cls._old_init
if hasattr(cls, '_old_init'):
cls.__init__ = cls._old_init

for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_disable_class(subclass)
Expand Down
Loading