2025-01-07 nightly release (6f4bfe2)
pytorchbot committed Jan 7, 2025
1 parent be0f3db commit fdc60ae
Showing 8 changed files with 207 additions and 44 deletions.
8 changes: 8 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -20,6 +20,9 @@
from torch.distributed._tensor.placement_types import Placement
from torch.nn.modules.module import _addindent
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.global_settings import (
construct_sharded_tensor_from_metadata_enabled,
)
from torchrec.distributed.types import (
get_tensor_size_bytes,
ModuleSharder,
@@ -343,6 +346,11 @@ def __init__(
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []

# option to construct ShardedTensor from metadata avoiding expensive all-gather
self._construct_sharded_tensor_from_metadata: bool = (
construct_sharded_tensor_from_metadata_enabled()
)

def prefetch(
self,
dist_input: KJTList,
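The flag added above is read once, at module construction. A minimal sketch of that behavior, assuming a TorchRec build that includes this commit (the _DemoShardedModule class below is hypothetical; only the attribute name and the helper come from the diff): the environment variable has to be set before the sharded module is created, because later changes are not re-read.

import os

from torchrec.distributed.global_settings import (
    construct_sharded_tensor_from_metadata_enabled,
)


class _DemoShardedModule:
    def __init__(self) -> None:
        # read once at construction, mirroring the __init__ shown above
        self._construct_sharded_tensor_from_metadata: bool = (
            construct_sharded_tensor_from_metadata_enabled()
        )


os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"
module = _DemoShardedModule()

# flipping the variable afterwards does not affect the already-built module
os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "0"
assert module._construct_sharded_tensor_from_metadata is True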
51 changes: 42 additions & 9 deletions torchrec/distributed/embeddingbag.py
@@ -29,6 +29,7 @@
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharded_tensor import TensorProperties
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
@@ -81,6 +82,7 @@
optimizer_type_to_emb_opt_type,
)
from torchrec.modules.embedding_configs import (
data_type_to_dtype,
EmbeddingBagConfig,
EmbeddingTableConfig,
PoolingType,
@@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None:  # noqa
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=(
-                        self._env.sharding_pg
-                        if isinstance(self._env, ShardingEnv2D)
-                        else self._env.process_group
-                    ),
-                )
-            )
+
+            # create ShardedTensor from local shards and metadata avoiding all_gather collective
+            if self._construct_sharded_tensor_from_metadata:
+                sharding_spec = none_throws(
+                    self.module_sharding_plan[table_name].sharding_spec
+                )
+
+                tensor_properties = TensorProperties(
+                    dtype=(
+                        data_type_to_dtype(
+                            self._table_name_to_config[table_name].data_type
+                        )
+                    ),
+                )
+
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards_and_global_metadata(
+                        local_shards=local_shards,
+                        sharded_tensor_metadata=sharding_spec.build_metadata(
+                            tensor_sizes=self._name_to_table_size[table_name],
+                            tensor_properties=tensor_properties,
+                        ),
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
+            else:
+                # create ShardedTensor from local shards using all_gather collective
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )

def extract_sharded_kvtensors(
module: ShardedEmbeddingBagCollection,
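What makes the metadata path cheap is that the sharding spec alone is enough to derive the global ShardedTensorMetadata on every rank, so no collective is needed to discover the other ranks' shards. A standalone sketch of that step with assumed values (the ChunkShardingSpec, table size, dtype, and two-rank CPU placement are illustrative, not taken from the commit):

import torch
from torch.distributed._shard.sharded_tensor import TensorProperties
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# hypothetical two-rank row-wise layout for an 8 x 16 table
spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cpu", "rank:1/cpu"],
)
metadata = spec.build_metadata(
    tensor_sizes=torch.Size([8, 16]),
    tensor_properties=TensorProperties(dtype=torch.float32),
)
for shard_md in metadata.shards_metadata:
    print(shard_md.shard_offsets, shard_md.shard_sizes, shard_md.placement)

Every rank can compute this same metadata independently and hand it, together with its local shards, to ShardedTensor._init_from_local_shards_and_global_metadata, whereas the _init_from_local_shards path in the else branch gathers shard metadata from all ranks first.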
12 changes: 12 additions & 0 deletions torchrec/distributed/global_settings.py
@@ -7,8 +7,14 @@

# pyre-strict

import os

PROPOGATE_DEVICE: bool = False

TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = (
"TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"
)


def set_propogate_device(val: bool) -> None:
global PROPOGATE_DEVICE
@@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None:
def get_propogate_device() -> bool:
global PROPOGATE_DEVICE
return PROPOGATE_DEVICE


def construct_sharded_tensor_from_metadata_enabled() -> bool:
return (
os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
)
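A quick sanity check of the helper's semantics, assuming a TorchRec build that includes this change: only the exact string "1" enables the option.

import os

from torchrec.distributed.global_settings import (
    construct_sharded_tensor_from_metadata_enabled,
)

os.environ.pop("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", None)
assert construct_sharded_tensor_from_metadata_enabled() is False  # default: off

os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "true"
assert construct_sharded_tensor_from_metadata_enabled() is False  # not exactly "1"

os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"
assert construct_sharded_tensor_from_metadata_enabled() is True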