Increase tensor creator coverage (#3684)
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
3 people authored Jun 8, 2023
1 parent fc8e5c8 commit 046afce
Showing 1 changed file with 22 additions and 10 deletions:
deepspeed/runtime/zero/partition_parameters.py
@@ -220,6 +220,8 @@ class ZeroParamStatus(Enum):
 _orig_torch_zeros = torch.zeros
 _orig_torch_ones = torch.ones
 _orig_torch_full = torch.full
+_orig_torch_arange = torch.arange
+_orig_torch_eye = torch.eye


 def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
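The saved `_orig_*` handles let the wrappers delegate to the real constructors after patching. The wrapper body itself is outside this hunk; the following is a minimal sketch of the pattern it implements, assuming the repo version additionally handles accelerator device placement:

```python
import torch
from typing import Callable


def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
    """Wrap a tensor constructor so floating-point results land in target_fp_dtype."""

    def wrapped_fn(*args, **kwargs) -> torch.Tensor:
        # Delegate to the saved original constructor.
        tensor = fn(*args, **kwargs)
        # Only floating-point outputs are cast; integer results (e.g.
        # torch.arange over int bounds) keep their original dtype.
        if tensor.is_floating_point():
            tensor = tensor.to(target_fp_dtype)
        return tensor

    return wrapped_fn
```

Wrapping `torch.arange` and `torch.eye` matters because both default to `float32` when they produce floating-point output, which would otherwise escape the configured half precision.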
@@ -424,11 +426,7 @@ def _init_subclass(cls, **kwargs):
         torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
         torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)

-        torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
-        torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
-        torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
-        torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
-        torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
+        self._add_tensor_creation_wrappers()

         if self.mem_efficient_linear:
             print_rank_0(
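This hunk collapses the five inline monkey-patches into a single helper call. `get_new_tensor_fn_for_dtype`, which replaces `torch.Tensor.__new__`, is referenced here but not shown; a simplified sketch of its shape, omitting the accelerator-device placement the full version performs:

```python
import torch
from typing import Callable

_orig_torch_empty = torch.empty  # saved before torch.empty is patched


def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
    def new_tensor(cls, *args) -> torch.Tensor:
        # Allocate through the saved original so the patched torch.empty
        # is not re-entered, then cast floating-point results.
        tensor = _orig_torch_empty(0).new_empty(*args)
        if tensor.is_floating_point():
            tensor = tensor.to(dtype)
        return tensor

    return new_tensor
```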
@@ -476,6 +474,24 @@ def _set_dtype(self, ds_config, dtype):
         else:
             self.dtype = dtype or torch.half

+    def _add_tensor_creation_wrappers(self):
+        torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
+        torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
+        torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
+        torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
+        torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
+        torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
+        torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
+
+    def _remove_tensor_creation_wrappers(self):
+        torch.Tensor.__new__ = torch.Tensor.__old_new__
+        torch.empty = _orig_torch_empty
+        torch.zeros = _orig_torch_zeros
+        torch.ones = _orig_torch_ones
+        torch.full = _orig_torch_full
+        torch.arange = _orig_torch_arange
+        torch.eye = _orig_torch_eye
+
     def remove_wrappers(self):

         def _disable_class(cls):
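The two new helpers are deliberately symmetric: everything `_add_tensor_creation_wrappers` patches, `_remove_tensor_creation_wrappers` restores from the module-level originals. A standalone sketch of the same pairing (the names here are hypothetical, and `zero_wrapper_for_fp_tensor_constructor` is the sketch from above), using `try/finally` to guarantee restoration even if an error occurs while patched:

```python
import contextlib
import torch


@contextlib.contextmanager
def patched_tensor_creators(target_dtype: torch.dtype):
    # Snapshot the originals once, patch, and restore on exit.
    names = ("empty", "zeros", "ones", "full", "arange", "eye")
    originals = {name: getattr(torch, name) for name in names}
    try:
        for name, fn in originals.items():
            setattr(torch, name, zero_wrapper_for_fp_tensor_constructor(fn, target_dtype))
        yield
    finally:
        for name, fn in originals.items():
            setattr(torch, name, fn)
```

Restoring from a snapshot rather than trying to unwrap makes the exit path idempotent, which is why the originals are captured once at module import time.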
@@ -491,11 +507,7 @@ def _disable_class(cls):
         torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
         torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

-        torch.Tensor.__new__ = torch.Tensor.__old_new__
-        torch.empty = _orig_torch_empty
-        torch.zeros = _orig_torch_zeros
-        torch.ones = _orig_torch_ones
-        torch.full = _orig_torch_full
+        self._remove_tensor_creation_wrappers()

         # un doing it here will undo it during training
         # if self.mem_efficient_linear:
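For orientation, these wrappers are active inside the `deepspeed.zero.Init` context manager: they are installed on entry and removed on exit, so constructor behavior outside the block is untouched. A hedged usage sketch (the config dict is a minimal placeholder):

```python
import torch
import torch.nn as nn
import deepspeed

ds_config = {"train_batch_size": 8, "zero_optimization": {"stage": 3}}  # placeholder config

# Inside the context, patched constructors emit tensors in the configured
# dtype, and newly constructed module parameters are partitioned across ranks.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = nn.Linear(512, 512)

# After exit, torch.empty/zeros/ones/full/arange/eye are the stock versions again.
eye = torch.eye(3)  # default dtype, no ZeRO bookkeeping
```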
