Reland of D65489998: Optimize sharding performance of embeddings
Summary:
X-link: pytorch/FBGEMM#3549

X-link: facebookresearch/FBGEMM#634

This diff is a reland of D65489998 after backout in D66800554.

Reviewed By: iamzainhuda

Differential Revision: D66828907
Boris Sarana authored and facebook-github-bot committed Jan 6, 2025
1 parent 504642a commit 2c75b6e
Showing 5 changed files with 169 additions and 43 deletions.
7 changes: 7 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -9,6 +9,7 @@

import abc
import copy
import os
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -343,6 +344,12 @@ def __init__(
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []

# option to construct ShardedTensor from metadata avoiding expensive all-gather
self._construct_sharded_tensor_from_metadata: bool = (
os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0")
== "1"
)

def prefetch(
self,
dist_input: KJTList,
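The flag added above is read from the environment once, at module construction time, so it must be set before the model is sharded. A minimal sketch of opting in (the surrounding training script is assumed and not part of this diff; only the environment variable name comes from the change above):

import os

# Must be set before the sharded embedding modules are constructed, because
# __init__ reads the variable exactly once (see the diff above).
os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"

# ... build and shard the TorchRec model as usual from here on ...

Equivalently, the variable can be exported in the job environment (e.g. TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA=1) before launch.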
51 changes: 42 additions & 9 deletions torchrec/distributed/embeddingbag.py
@@ -29,6 +29,7 @@
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharded_tensor import TensorProperties
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
@@ -81,6 +82,7 @@
optimizer_type_to_emb_opt_type,
)
from torchrec.modules.embedding_configs import (
data_type_to_dtype,
EmbeddingBagConfig,
EmbeddingTableConfig,
PoolingType,
@@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None:  # noqa
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
self._model_parallel_name_to_sharded_tensor[table_name] = (
    ShardedTensor._init_from_local_shards(
        local_shards,
        self._name_to_table_size[table_name],
        process_group=(
            self._env.sharding_pg
            if isinstance(self._env, ShardingEnv2D)
            else self._env.process_group
        ),
    )
)

# create ShardedTensor from local shards and metadata, avoiding the all_gather collective
if self._construct_sharded_tensor_from_metadata:
    sharding_spec = none_throws(
        self.module_sharding_plan[table_name].sharding_spec
    )

    tensor_properties = TensorProperties(
        dtype=(
            data_type_to_dtype(
                self._table_name_to_config[table_name].data_type
            )
        ),
    )

    self._model_parallel_name_to_sharded_tensor[table_name] = (
        ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards=local_shards,
            sharded_tensor_metadata=sharding_spec.build_metadata(
                tensor_sizes=self._name_to_table_size[table_name],
                tensor_properties=tensor_properties,
            ),
            process_group=(
                self._env.sharding_pg
                if isinstance(self._env, ShardingEnv2D)
                else self._env.process_group
            ),
        )
    )
else:
    # create ShardedTensor from local shards using the all_gather collective
    self._model_parallel_name_to_sharded_tensor[table_name] = (
        ShardedTensor._init_from_local_shards(
            local_shards,
            self._name_to_table_size[table_name],
            process_group=(
                self._env.sharding_pg
                if isinstance(self._env, ShardingEnv2D)
                else self._env.process_group
            ),
        )
    )

def extract_sharded_kvtensors(
module: ShardedEmbeddingBagCollection,
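To make the intent of the new code path concrete, here is a minimal, single-rank sketch of metadata-based ShardedTensor construction outside of TorchRec. The table size, ChunkShardingSpec, and single-process gloo group are made up for illustration; the calls that matter (build_metadata and _init_from_local_shards_and_global_metadata) are the ones used in the diff above. Because the global metadata is derived locally from the sharding spec, no all_gather across ranks is needed.

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Single-process group, only to make the sketch runnable end to end.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

table_size = torch.Size([8, 4])  # hypothetical embedding table shape
sharding_spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])

# The global metadata is built locally from the sharding spec and tensor
# properties -- this is what replaces the all_gather-based construction.
metadata = sharding_spec.build_metadata(
    tensor_sizes=table_size,
    tensor_properties=TensorProperties(dtype=torch.float32),
)

# On a single rank the only shard covers the whole table.
local_shards = [
    Shard(tensor=torch.zeros(8, 4), metadata=metadata.shards_metadata[0])
]

sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata(
    local_shards=local_shards,
    sharded_tensor_metadata=metadata,
)
print(sharded_tensor)

dist.destroy_process_group()

In the multi-rank TorchRec case, the sharding spec comes from module_sharding_plan and the dtype from the table config via data_type_to_dtype, exactly as in the diff; the structure of the call is the same.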