diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index e24a695d9..78f9ab9ef 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -29,7 +29,7 @@
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, Shard
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ -68,7 +68,6 @@
     ShardingEnv2D,
     ShardingType,
     ShardMetadata,
-    TensorProperties,
 )
 from torchrec.distributed.utils import (
     add_params_from_parameter_sharding,
@@ -943,31 +942,11 @@ def _initialize_torch_state(self) -> None:  # noqa
             # created ShardedTensors once in init, use in post_state_dict_hook
             # note: at this point kvstore backed tensors don't own valid snapshots, so no read
             # access is allowed on them.
-            sharding_spec = none_throws(
-                self.module_sharding_plan[table_name].sharding_spec
-            )
-            metadata = sharding_spec.build_metadata(
-                tensor_sizes=self._name_to_table_size[table_name],
-                tensor_properties=(
-                    TensorProperties(
-                        dtype=local_shards[0].tensor.dtype,
-                        layout=local_shards[0].tensor.layout,
-                        requires_grad=local_shards[0].tensor.requires_grad,
-                    )
-                    if local_shards
-                    else TensorProperties()
-                ),
-            )
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards_and_global_metadata(
-                    local_shards=local_shards,
-                    sharded_tensor_metadata=metadata,
-                    process_group=(
-                        self._env.sharding_pg
-                        if isinstance(self._env, ShardingEnv2D)
-                        else self._env.process_group
-                    ),
+            self._model_parallel_name_to_sharded_tensor[table_name] = (
+                ShardedTensor._init_from_local_shards(
+                    local_shards,
+                    self._name_to_table_size[table_name],
+                    process_group=self._env.process_group,
                 )
             )
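
Note (not part of the patch): the sketch below contrasts, for reviewers, the two ShardedTensor construction paths the last hunk switches between. It is illustrative only, not torchrec code; the 4x8 toy table, the single-rank gloo group, and the hardcoded TCP port are assumptions made for the example. _init_from_local_shards derives the global metadata itself by all-gathering shard metadata over the process group, whereas the removed path builds a ShardedTensorMetadata from a sharding spec and hands it to _init_from_local_shards_and_global_metadata, which validates locally without issuing that collective.

    # Minimal sketch (not torchrec code) of the two construction paths.
    # Assumptions for illustration only: a toy 4x8 table fully owned by
    # rank 0, a single-process gloo group, and a hardcoded TCP port.
    import torch
    import torch.distributed as dist
    from torch.distributed._shard.metadata import ShardMetadata
    from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
    from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
    from torch.distributed._shard.sharding_spec import EnumerableShardingSpec


    def main() -> None:
        dist.init_process_group(
            backend="gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1
        )
        pg = dist.group.WORLD

        # Rank 0 owns the whole (4, 8) table as a single local shard at offset (0, 0).
        local_tensor = torch.randn(4, 8)
        local_shards = [Shard.from_tensor_and_offsets(local_tensor, [0, 0], rank=0)]

        # Path kept by the diff: infer the global metadata from the local shards
        # (performs an all-gather of shard metadata over `pg`).
        st_new = ShardedTensor._init_from_local_shards(
            local_shards,
            [4, 8],  # global size, analogous to self._name_to_table_size[table_name]
            process_group=pg,
        )

        # Path removed by the diff: build the global metadata from a sharding spec
        # and pass it in explicitly; no metadata collective is issued.
        spec = EnumerableShardingSpec(
            [ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 8], placement="rank:0/cpu")]
        )
        metadata = spec.build_metadata(
            tensor_sizes=torch.Size([4, 8]),
            tensor_properties=TensorProperties(
                dtype=local_tensor.dtype,
                layout=local_tensor.layout,
                requires_grad=local_tensor.requires_grad,
            ),
        )
        st_old = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards=local_shards,
            sharded_tensor_metadata=metadata,
            process_group=pg,
        )

        assert st_new.size() == st_old.size() == torch.Size([4, 8])
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Run it as a single process (python sketch.py). In torchrec the process group instead comes from the sharding environment: self._env.process_group, or self._env.sharding_pg when the environment is a ShardingEnv2D, as in the removed lines above.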