Move sharding optimization flag to global_settings #2665

Closed · wants to merge 2 commits
8 changes: 8 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -20,6 +20,9 @@
from torch.distributed._tensor.placement_types import Placement
from torch.nn.modules.module import _addindent
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.global_settings import (
    construct_sharded_tensor_from_metadata_enabled,
)
from torchrec.distributed.types import (
    get_tensor_size_bytes,
    ModuleSharder,
@@ -343,6 +346,11 @@ def __init__(
        self._lookups: List[nn.Module] = []
        self._output_dists: List[nn.Module] = []

        # option to construct ShardedTensor from metadata avoiding expensive all-gather
        self._construct_sharded_tensor_from_metadata: bool = (
            construct_sharded_tensor_from_metadata_enabled()
        )

    def prefetch(
        self,
        dist_input: KJTList,
51 changes: 42 additions & 9 deletions torchrec/distributed/embeddingbag.py
@@ -29,6 +29,7 @@
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharded_tensor import TensorProperties
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
@@ -81,6 +82,7 @@
    optimizer_type_to_emb_opt_type,
)
from torchrec.modules.embedding_configs import (
    data_type_to_dtype,
    EmbeddingBagConfig,
    EmbeddingTableConfig,
    PoolingType,
@@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None:  # noqa
            # created ShardedTensors once in init, use in post_state_dict_hook
            # note: at this point kvstore backed tensors don't own valid snapshots, so no read
            # access is allowed on them.

            # create ShardedTensor from local shards and metadata avoiding all_gather collective
            if self._construct_sharded_tensor_from_metadata:
                sharding_spec = none_throws(
                    self.module_sharding_plan[table_name].sharding_spec
                )

                tensor_properties = TensorProperties(
                    dtype=(
                        data_type_to_dtype(
                            self._table_name_to_config[table_name].data_type
                        )
                    ),
                )

                self._model_parallel_name_to_sharded_tensor[table_name] = (
                    ShardedTensor._init_from_local_shards_and_global_metadata(
                        local_shards=local_shards,
                        sharded_tensor_metadata=sharding_spec.build_metadata(
                            tensor_sizes=self._name_to_table_size[table_name],
                            tensor_properties=tensor_properties,
                        ),
                        process_group=(
                            self._env.sharding_pg
                            if isinstance(self._env, ShardingEnv2D)
                            else self._env.process_group
                        ),
                    )
                )
            else:
                # create ShardedTensor from local shards using all_gather collective
                self._model_parallel_name_to_sharded_tensor[table_name] = (
                    ShardedTensor._init_from_local_shards(
                        local_shards,
                        self._name_to_table_size[table_name],
                        process_group=(
                            self._env.sharding_pg
                            if isinstance(self._env, ShardingEnv2D)
                            else self._env.process_group
                        ),
                    )
                )

        def extract_sharded_kvtensors(
            module: ShardedEmbeddingBagCollection,
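For context on the branch above: per the diff's own comments, ShardedTensor._init_from_local_shards exchanges shard metadata with an all_gather collective, while _init_from_local_shards_and_global_metadata builds the tensor from metadata each rank can compute locally from its sharding spec, so no collective is issued. Below is a minimal single-process sketch of that construction path outside TorchRec; the gloo process group, the 4x8 table size, and the single-shard placement are illustrative assumptions, not part of this PR.

import os

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, TensorProperties
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata

# Single-process gloo group so the sketch runs without torchrun (assumption).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Hypothetical sharding spec: the whole 4x8 table lives on rank 0.
spec = EnumerableShardingSpec(
    [ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 8], placement="rank:0/cpu")]
)

# Build the global metadata locally from the spec -- no communication involved.
global_metadata = spec.build_metadata(
    tensor_sizes=torch.Size([4, 8]),
    tensor_properties=TensorProperties(dtype=torch.float32),
)

# Wrap the local shard and construct the ShardedTensor from the known metadata.
local_shards = [Shard(tensor=torch.zeros(4, 8), metadata=global_metadata.shards_metadata[0])]
st = ShardedTensor._init_from_local_shards_and_global_metadata(
    local_shards=local_shards,
    sharded_tensor_metadata=global_metadata,
    process_group=dist.group.WORLD,
)
print(st.size())  # torch.Size([4, 8])

dist.destroy_process_group()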
12 changes: 12 additions & 0 deletions torchrec/distributed/global_settings.py
@@ -7,8 +7,14 @@

# pyre-strict

import os

PROPOGATE_DEVICE: bool = False

TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = (
    "TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"
)


def set_propogate_device(val: bool) -> None:
    global PROPOGATE_DEVICE
@@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None:
def get_propogate_device() -> bool:
    global PROPOGATE_DEVICE
    return PROPOGATE_DEVICE


def construct_sharded_tensor_from_metadata_enabled() -> bool:
    return (
        os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
    )
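A brief usage sketch for the new switch; only the environment variable name and the helper come from this diff, the surrounding code is illustrative. Note that construct_sharded_tensor_from_metadata_enabled() re-reads the environment on every call, but the sharded modules cache the result in __init__, so the variable should be exported before the model is sharded.

import os

# Enable the metadata-based ShardedTensor construction before any sharded
# module is created; the module reads the flag once in its __init__.
os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"

from torchrec.distributed.global_settings import (
    construct_sharded_tensor_from_metadata_enabled,
)

# The helper simply compares the environment value against "1".
assert construct_sharded_tensor_from_metadata_enabled()

# Unset (or set to "0") to fall back to the all_gather-based construction.
os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "0"
assert not construct_sharded_tensor_from_metadata_enabled()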