Skip to content

Commit

Permalink
Remove more unused classes
Browse files Browse the repository at this point in the history
Summary:
Removes three unused classes:

- `InferCPUGroupedEmbeddingsLookup`

Reviewed By: PaulZhang12

Differential Revision: D67153268
  • Loading branch information
sarckk authored and facebook-github-bot committed Dec 13, 2024
1 parent c4005c9 commit a5f44eb
Showing 1 changed file with 0 additions and 33 deletions.
33 changes: 0 additions & 33 deletions torchrec/distributed/embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,36 +1101,3 @@ def get_tbes_to_register(
self,
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank)


class InferCPUGroupedEmbeddingsLookup(
InferGroupedLookupMixin,
BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]],
TBEToRegisterMixIn,
):
def __init__(
self,
grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]],
world_size: int,
fused_params: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []

device_type: str = "cpu" if device is None else device.type
for rank in range(world_size):
self._embedding_lookups_per_rank.append(
MetaInferGroupedEmbeddingsLookup(
grouped_configs=grouped_configs_per_rank[rank],
# syntax for torchscript
# pyre-fixme[20]: Argument `index` expected.
device=torch.device(type=device_type),
fused_params=fused_params,
)
)

def get_tbes_to_register(
self,
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank)

0 comments on commit a5f44eb

Please sign in to comment.