From d555266cf0e458d8ea21d26269f5b18b89c14e21 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 21 Aug 2024 01:51:54 +0000 Subject: [PATCH 1/2] fix zero init patch --- .../runtime/zero/partition_parameters.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index f76bcf0eb781..a5f3b9ddf11f 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -465,11 +465,13 @@ def wrapper(*args, **kwargs): return wrapper def _enable_class_apply(cls): - cls._old_apply_of_skip_init_hook = cls._apply - cls._apply = partition_after_empty_init(cls._apply) + if '_apply' in cls.__dict__: + cls._old_apply_of_skip_init_hook = cls._apply + cls._apply = partition_after_empty_init(cls._apply) def _disable_class_apply(cls): - cls._apply = cls._old_apply_of_skip_init_hook + if hasattr(cls, '_old_apply_of_skip_init_hook'): + cls._apply = cls._old_apply_of_skip_init_hook # add hooks for to_empty: apply_(empty_like) for subclass in get_all_subclasses(torch.nn.modules.module.Module): @@ -522,12 +524,14 @@ def wrapper(module, *args, **kwargs): return wrapper def _enable_class(cls): - cls._old_init = cls.__init__ - cls.__init__ = partition_after(cls.__init__) + if '__init__' in cls.__dict__: + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): - cls._old_init = cls.__init__ - cls.__init__ = partition_after(cls.__init__) + if '__init__' in cls.__dict__: + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module recursively for subclass in get_all_subclasses(torch.nn.modules.module.Module): @@ -567,7 +571,8 @@ def unpatch_init_and_builtins(self): if self.patched: def _disable_class(cls): - cls.__init__ = cls._old_init + if hasattr(cls, '_old_init'): + cls.__init__ = cls._old_init for subclass in get_all_subclasses(torch.nn.modules.module.Module): _disable_class(subclass) From f064ebbc09ec431d7b8a716a2ae8183b81dcc9a7 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 21 Aug 2024 20:46:01 +0000 Subject: [PATCH 2/2] patch the root class --- deepspeed/runtime/zero/partition_parameters.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index a5f3b9ddf11f..85a15ac7c2db 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -262,7 +262,7 @@ def new_tensor(cls, *args, **kwargs) -> Tensor: # https://stackoverflow.com/a/63851681/9201239 -def get_all_subclasses(cls): +def get_all_subclasses(cls, include_root=True): subclass_list = [] def recurse(cl): @@ -272,7 +272,10 @@ def recurse(cl): recurse(cls) - return set(subclass_list) + ret = set(subclass_list) + if include_root: + ret.add(cls) + return ret @instrument_w_nvtx