From 96333634bd96fab542d1cbfcc0aa657eb04b7740 Mon Sep 17 00:00:00 2001 From: Ziyang Date: Tue, 13 Aug 2024 02:04:53 +0800 Subject: [PATCH 1/5] Fix incorrect patch in zero.init --- .../runtime/zero/partition_parameters.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index f76bcf0eb781..99a65372639f 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -464,16 +464,22 @@ def wrapper(*args, **kwargs): return wrapper - def _enable_class_apply(cls): + def _enable_class_apply_backup(cls): cls._old_apply_of_skip_init_hook = cls._apply + + def _enable_class_apply_replace(cls): cls._apply = partition_after_empty_init(cls._apply) def _disable_class_apply(cls): cls._apply = cls._old_apply_of_skip_init_hook # add hooks for to_empty: apply_(empty_like) - for subclass in get_all_subclasses(torch.nn.modules.module.Module): - _enable_class_apply(subclass) + all_subclasses = get_all_subclasses(torch.nn.modules.module.Module) + # split into two steps to address the inheritance problem + for subclass in all_subclasses: + _enable_class_apply_backup(subclass) + for subclass in all_subclasses: + _enable_class_apply_replace(subclass) # add a restore hook when exiting skip_init module.to_empty = post_wrapper_to_empty(module.to_empty) @@ -521,8 +527,10 @@ def wrapper(module, *args, **kwargs): return wrapper - def _enable_class(cls): + def _enable_class_backup(cls): cls._old_init = cls.__init__ + + def _enable_class_repalce(cls): cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): @@ -530,8 +538,12 @@ def _init_subclass(cls, **kwargs): cls.__init__ = partition_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module recursively - for subclass in get_all_subclasses(torch.nn.modules.module.Module): - _enable_class(subclass) + all_subclasses = get_all_subclasses(torch.nn.modules.module.Module) + # split into two steps to address the inheritance problem + for subclass in all_subclasses: + _enable_class_backup(subclass) + for subclass in all_subclasses: + _enable_class_replace(subclass) # holding onto some methods so we can put them back the way they were in __exit__ torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ From 5f564b608094947c2b56c1fab251c074d364f9f3 Mon Sep 17 00:00:00 2001 From: Ziyang Date: Tue, 13 Aug 2024 02:22:46 +0800 Subject: [PATCH 2/5] Fix typo --- deepspeed/runtime/zero/partition_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 99a65372639f..09b3bc740802 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -530,7 +530,7 @@ def wrapper(module, *args, **kwargs): def _enable_class_backup(cls): cls._old_init = cls.__init__ - def _enable_class_repalce(cls): + def _enable_class_replace(cls): cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): From 348b29df2125fd1b1344cf891cdea7d0d212565f Mon Sep 17 00:00:00 2001 From: Ziyang Date: Tue, 13 Aug 2024 11:43:57 +0800 Subject: [PATCH 3/5] Better solution to handle _init_subclass as well --- .../runtime/zero/partition_parameters.py | 46 ++++++++----------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 09b3bc740802..8fb6b66b9352 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -447,7 +447,7 @@ def wrapper(module, *args, **kwargs): # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init self._post_init_method(_module) return _module - + wrapped._ds_has_wrapped = True return wrapper def post_wrapper_to_empty(f): @@ -464,22 +464,18 @@ def wrapper(*args, **kwargs): return wrapper - def _enable_class_apply_backup(cls): - cls._old_apply_of_skip_init_hook = cls._apply - - def _enable_class_apply_replace(cls): - cls._apply = partition_after_empty_init(cls._apply) + def _enable_class_apply(cls): + # avoid re-wrap + if not hasattr(cls._apply, '_ds_has_wrapped'): + cls._old_apply_of_skip_init_hook = cls._apply + cls._apply = partition_after_empty_init(cls._apply) def _disable_class_apply(cls): cls._apply = cls._old_apply_of_skip_init_hook # add hooks for to_empty: apply_(empty_like) - all_subclasses = get_all_subclasses(torch.nn.modules.module.Module) - # split into two steps to address the inheritance problem - for subclass in all_subclasses: - _enable_class_apply_backup(subclass) - for subclass in all_subclasses: - _enable_class_apply_replace(subclass) + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class_apply(subclass) # add a restore hook when exiting skip_init module.to_empty = post_wrapper_to_empty(module.to_empty) @@ -524,26 +520,24 @@ def wrapper(module, *args, **kwargs): print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False) if init_on_meta: self.skip_init_depth -= 1 - + wrapped._ds_has_wrapped = True return wrapper - def _enable_class_backup(cls): - cls._old_init = cls.__init__ - - def _enable_class_replace(cls): - cls.__init__ = partition_after(cls.__init__) + def _enable_class(cls): + # avoid re-wrap + if not hasattr(cls.__init__, '_ds_has_wrapped'): + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): - cls._old_init = cls.__init__ - cls.__init__ = partition_after(cls.__init__) + # avoid re-wrap + if not hasattr(cls.__init__, '_ds_has_wrapped'): + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module recursively - all_subclasses = get_all_subclasses(torch.nn.modules.module.Module) - # split into two steps to address the inheritance problem - for subclass in all_subclasses: - _enable_class_backup(subclass) - for subclass in all_subclasses: - _enable_class_replace(subclass) + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class(subclass) # holding onto some methods so we can put them back the way they were in __exit__ torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ From bf845228bc8c59b6b368d01bf648bba276f1ce86 Mon Sep 17 00:00:00 2001 From: Ziyang Date: Tue, 13 Aug 2024 11:54:29 +0800 Subject: [PATCH 4/5] Fix typo --- deepspeed/runtime/zero/partition_parameters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 8fb6b66b9352..75871884f820 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -447,7 +447,7 @@ def wrapper(module, *args, **kwargs): # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init self._post_init_method(_module) return _module - wrapped._ds_has_wrapped = True + wrapper._ds_has_wrapped = True return wrapper def post_wrapper_to_empty(f): @@ -520,7 +520,7 @@ def wrapper(module, *args, **kwargs): print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False) if init_on_meta: self.skip_init_depth -= 1 - wrapped._ds_has_wrapped = True + wrapper._ds_has_wrapped = True return wrapper def _enable_class(cls): From a4789201d2dcfc055bea1a29acc203996ec8eeec Mon Sep 17 00:00:00 2001 From: VeryLazyBoy Date: Wed, 14 Aug 2024 01:10:32 +0800 Subject: [PATCH 5/5] Fix pre-commit format errors --- deepspeed/runtime/zero/partition_parameters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 75871884f820..bb458c222700 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -447,6 +447,7 @@ def wrapper(module, *args, **kwargs): # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init self._post_init_method(_module) return _module + wrapper._ds_has_wrapped = True return wrapper @@ -520,6 +521,7 @@ def wrapper(module, *args, **kwargs): print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False) if init_on_meta: self.skip_init_depth -= 1 + wrapper._ds_has_wrapped = True return wrapper