diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index f12db2ca3..058b07ef9 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -30,7 +30,7 @@ from torch.autograd.profiler import record_function from torchrec.tensor_types import UInt2Tensor, UInt4Tensor -from torchrec.types import ModuleNoCopyMixin +from torchrec.types import DataType, ModuleNoCopyMixin try: # For python 3.6 and below, GenericMeta will be used by @@ -108,29 +108,6 @@ def _tabulate( return "\n".join(rows) -# moved DataType here to avoid circular import -# TODO: organize types and dependencies -@unique -class DataType(Enum): - """ - Our fusion implementation supports only certain types of data - so it makes sense to retrict in a non-fused version as well. - """ - - FP32 = "FP32" - FP16 = "FP16" - BF16 = "BF16" - INT64 = "INT64" - INT32 = "INT32" - INT8 = "INT8" - UINT8 = "UINT8" - INT4 = "INT4" - INT2 = "INT2" - - def __str__(self) -> str: - return self.value - - class ShardingType(Enum): """ Well-known sharding types, used by inter-module optimizations. diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py index 11d6dc65d..cf2c52019 100644 --- a/torchrec/modules/embedding_configs.py +++ b/torchrec/modules/embedding_configs.py @@ -14,7 +14,7 @@ import torch from fbgemm_gpu.split_embedding_configs import SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_training import PoolingMode -from torchrec.distributed.types import DataType +from torchrec.types import DataType @unique diff --git a/torchrec/types.py b/torchrec/types.py index 6788cbe2a..391cf2ca5 100644 --- a/torchrec/types.py +++ b/torchrec/types.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. 
from abc import abstractmethod +from enum import Enum, unique import torch from torch import nn @@ -35,3 +36,26 @@ class ModuleNoCopyMixin(CopyMixIn): def copy(self, device: torch.device) -> nn.Module: # pyre-ignore [7] return self + + +# moved DataType here to avoid circular import +# TODO: organize types and dependencies +@unique +class DataType(Enum): + """ + Our fusion implementation supports only certain types of data + so it makes sense to restrict in a non-fused version as well. + """ + + FP32 = "FP32" + FP16 = "FP16" + BF16 = "BF16" + INT64 = "INT64" + INT32 = "INT32" + INT8 = "INT8" + UINT8 = "UINT8" + INT4 = "INT4" + INT2 = "INT2" + + def __str__(self) -> str: + return self.value