move DataType to avoid circular dependency (#1515)
Summary:
Pull Request resolved: #1515

Although `DataType` was updated in D46010148 to remove a circular dependency, it can still cause issues in certain cases because a directory/module-level cycle remains. Move `DataType` to an upper-level file to fix this.

Reviewed By: bigning

Differential Revision: D51311334

fbshipit-source-id: 5aa52dfc4f8af976ad735984f8bbc77136b8491b
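
For context, here is a minimal sketch of the import shape this change produces. Everything below other than `DataType` itself is an illustrative stand-in, not the real torchrec layout; the point is that a shared enum defined in a leaf module with no intra-package imports can be consumed from anywhere without forming a cycle.

# upper_types.py -- illustrative stand-in for torchrec/types.py: a leaf
# module with no intra-package imports, so any other module may depend
# on it without creating a directory- or module-level cycle.
from enum import Enum, unique


@unique
class DataType(Enum):
    FP32 = "FP32"
    FP16 = "FP16"

    def __str__(self) -> str:
        return self.value


# In the real tree, both torchrec/distributed/types.py and
# torchrec/modules/embedding_configs.py now do
#     from torchrec.types import DataType
# so modules/ no longer imports from distributed/, which removes the
# directory-level edge that closed the cycle.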
Bin Wen authored and facebook-github-bot committed Nov 15, 2023
1 parent 5a4daed commit c6f036a
Showing 3 changed files with 26 additions and 25 deletions.
25 changes: 1 addition & 24 deletions torchrec/distributed/types.py
@@ -30,7 +30,7 @@

 from torch.autograd.profiler import record_function
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
-from torchrec.types import ModuleNoCopyMixin
+from torchrec.types import DataType, ModuleNoCopyMixin

 try:
     # For python 3.6 and below, GenericMeta will be used by
@@ -108,29 +108,6 @@ def _tabulate(
     return "\n".join(rows)


-# moved DataType here to avoid circular import
-# TODO: organize types and dependencies
-@unique
-class DataType(Enum):
-    """
-    Our fusion implementation supports only certain types of data
-    so it makes sense to restrict it in the non-fused version as well.
-    """
-
-    FP32 = "FP32"
-    FP16 = "FP16"
-    BF16 = "BF16"
-    INT64 = "INT64"
-    INT32 = "INT32"
-    INT8 = "INT8"
-    UINT8 = "UINT8"
-    INT4 = "INT4"
-    INT2 = "INT2"
-
-    def __str__(self) -> str:
-        return self.value
-
-
 class ShardingType(Enum):
     """
     Well-known sharding types, used by inter-module optimizations.
2 changes: 1 addition & 1 deletion torchrec/modules/embedding_configs.py
@@ -14,7 +14,7 @@
 import torch
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import PoolingMode
-from torchrec.distributed.types import DataType
+from torchrec.types import DataType


 @unique
24 changes: 24 additions & 0 deletions torchrec/types.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 from abc import abstractmethod
+from enum import Enum, unique

 import torch
 from torch import nn
@@ -35,3 +36,26 @@ class ModuleNoCopyMixin(CopyMixIn):
     def copy(self, device: torch.device) -> nn.Module:
         # pyre-ignore [7]
         return self
+
+
+# moved DataType here to avoid circular import
+# TODO: organize types and dependencies
+@unique
+class DataType(Enum):
+    """
+    Our fusion implementation supports only certain types of data
+    so it makes sense to restrict it in the non-fused version as well.
+    """
+
+    FP32 = "FP32"
+    FP16 = "FP16"
+    BF16 = "BF16"
+    INT64 = "INT64"
+    INT32 = "INT32"
+    INT8 = "INT8"
+    UINT8 = "UINT8"
+    INT4 = "INT4"
+    INT2 = "INT2"
+
+    def __str__(self) -> str:
+        return self.value

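As a quick sanity check of the relocated enum (a hedged sketch; it relies only on code visible in this diff), members stringify via the `__str__` override and round-trip through their values:

from torchrec.types import DataType

dt = DataType.FP16
print(dt)                       # -> FP16, via the __str__ override
assert DataType(str(dt)) is dt  # members round-trip through their values

Note that the import added to torchrec/distributed/types.py above keeps `DataType` bound in that module's namespace, so existing `from torchrec.distributed.types import DataType` call sites should continue to resolve.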