diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 06ad9f26e..84e033a31 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -65,6 +65,7 @@ QuantizedCommCodecs, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -938,7 +939,11 @@ def _initialize_torch_state(self) -> None: # noqa ShardedTensor._init_from_local_shards( local_shards, self._name_to_table_size[table_name], - process_group=self._env.process_group, + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), ) )