-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Jeff Rasley <[email protected]> Co-authored-by: Michael Wyatt <[email protected]> Co-authored-by: Ammar Ahmad Awan <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Reza Yazdani <[email protected]> Co-authored-by: Reza Yazdani <[email protected]>
- Loading branch information
1 parent
da652d0
commit 19b2587
Showing
98 changed files
with
924 additions
and
4,597 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from typing import Dict | ||
|
||
import torch | ||
|
||
CORE_PARAM = "_ds_core_param_key" | ||
|
||
STR_TO_DTYPE = { | ||
"torch.float32": torch.float32, | ||
"torch.float64": torch.float64, | ||
"torch.float16": torch.float16, | ||
"torch.bfloat16": torch.bfloat16, | ||
"torch.int64": torch.int64, | ||
"torch.int32": torch.int32, | ||
"torch.int16": torch.int16, | ||
"torch.int8": torch.int8, | ||
"torch.uint8": torch.uint8, | ||
"torch.bool": torch.bool, | ||
} | ||
|
||
|
||
class InferenceParameter(torch.Tensor): | ||
""" | ||
An extension of the torch.Tensor class to support our inference focused features. One important | ||
thing to note here is that an InferenceParam can be used a torch.Tensor, but outputs of | ||
torch.Tensor operations will not be InferenceParams. | ||
""" | ||
|
||
@staticmethod | ||
def __new__(cls, tensor, *args, **kwargs): | ||
new_tensor = super().__new__(cls, tensor, *args, **kwargs) | ||
if hasattr(tensor, "_aux_attrs"): | ||
setattr(new_tensor, "_aux_attrs", tensor.aux_attrs) | ||
return new_tensor | ||
|
||
def to(self, *args, **kwargs): | ||
new_tensor = super().to(*args, **kwargs) | ||
if hasattr(self, "_aux_attrs"): | ||
setattr(new_tensor, "_aux_attrs", self.aux_attrs) | ||
try: | ||
_ = torch.device(args[0]) | ||
for name, attr in new_tensor.aux_attrs.items(): | ||
new_attr = attr.to(*args, **kwargs) | ||
setattr(new_tensor, name, new_attr) | ||
new_tensor.aux_attrs[name] = new_attr | ||
except: | ||
pass | ||
|
||
return new_tensor | ||
|
||
@classmethod | ||
def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter': | ||
""" | ||
Create the inference parameter. | ||
""" | ||
param = InferenceParameter(core_param) | ||
setattr(param, "_aux_attrs", kwargs) | ||
|
||
for attr_name, attr in kwargs.items(): | ||
if hasattr(param, attr_name): | ||
raise ValueError(f"Attribute {attr_name} already exists on param.") | ||
|
||
if not isinstance(attr, torch.Tensor): | ||
raise ValueError(f"Attribute {attr_name} must be a tensor.") | ||
|
||
setattr(param, attr_name, attr) | ||
|
||
return param | ||
|
||
@classmethod | ||
def initialize_raw(self, **kwargs) -> 'InferenceParameter': | ||
""" | ||
All kwargs must be torch.Tensors and must include the core parameter. | ||
""" | ||
if CORE_PARAM not in kwargs: | ||
raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.") | ||
|
||
return InferenceParameter.initialize(kwargs[CORE_PARAM], **kwargs) | ||
|
||
@property | ||
def aux_attrs(self) -> Dict[str, torch.Tensor]: | ||
""" | ||
Dictionary of auxiliary attributes. | ||
""" | ||
return self._aux_attrs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.