Optimize performance of embeddings sharding
Summary:
While working on TTFB, we observed that sharding of the embedding bags takes significant time and is one of the biggest contributors to TTFB, especially on large jobs.
Analysis of Strobelight data made it clear that most of this time is spent in all_gather collective calls: sharded tensors are currently constructed one by one, issuing a collective call per tensor to exchange metadata, which is inefficient. A more optimal approach is to let all ranks build their portion of the metadata for all tensors and exchange it with a single collective call, significantly reducing overhead and improving performance.

Testing on 256 ranks showed a ~13x speed-up.
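For illustration, the change follows the pattern sketched below (a minimal sketch, not the actual torchrec code: the helper name, the `pg` argument, and the shape of the metadata dict are hypothetical stand-ins). Each rank packs the metadata for every table it owns into a single dict, all ranks exchange these dicts with one all_gather_object call, and the per-table global metadata is then assembled locally with no further communication:

# Minimal sketch of the single-collective metadata exchange.
# Assumes torch.distributed is already initialized; names are illustrative.
from typing import Any, Dict, List, Optional

import torch.distributed as dist


def gather_metadata_once(
    local_md_by_table: Dict[str, Any],
    pg: dist.ProcessGroup,
) -> Dict[str, List[Optional[Any]]]:
    world_size = dist.get_world_size(group=pg)
    gathered: List[Dict[str, Any]] = [{} for _ in range(world_size)]
    # One collective for all tables, instead of one collective per sharded tensor.
    dist.all_gather_object(gathered, local_md_by_table, group=pg)
    # Merge locally: for each table, line up every rank's contribution,
    # with None for ranks that hold no shard of that table.
    all_tables = set().union(*(md.keys() for md in gathered))
    return {table: [md.get(table) for md in gathered] for table in all_tables}

This is why the cost drops from one collective per table to a single collective overall, which is where the reduction in overhead comes from on large rank counts.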

Differential Revision: D65489998
Boris Sarana authored and facebook-github-bot committed Nov 15, 2024
1 parent 9fb6b8e commit b3649c8
177 changes: 110 additions & 67 deletions torchrec/distributed/embeddingbag.py
@@ -29,7 +29,14 @@
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
-from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
+from torch.distributed._shard.sharded_tensor.utils import (
+    _flatten_tensor_size,
+    build_global_metadata,
+    build_metadata_from_local_shards,
+)
+from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_sharding import (
@@ -98,13 +105,6 @@
    pass


-# OSS
-try:
-    pass
-except ImportError:
-    pass
-
-
def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    return (
        tensor
@@ -615,10 +615,9 @@ def __init__(
        )
        self._env = env
        # output parameters as DTensor in state dict
-        self._output_dtensor: bool = (
-            fused_params.get("output_dtensor", False) if fused_params else False
+        self._output_dtensor: bool = fused_params and fused_params.get(
+            "output_dtensor", False
        )
-
        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
            module,
            table_name_to_parameter_sharding,
@@ -818,11 +817,103 @@ def _pre_load_state_dict_hook(
                lookup = lookup.module
            lookup.purge()

+    def _construct_sharded_tensors_map(
+        self,
+        name_to_kernel: Dict[str, str],
+    ) -> None:
+        # Collect metadata for all distributed tensors, do single collective call
+        # and merge them to avoid overhead doing it per tensor
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+
+        rank_metadata = {
+            table_name: (
+                build_metadata_from_local_shards(
+                    local_shards,
+                    _flatten_tensor_size(self._name_to_table_size[table_name]),
+                    rank,
+                    none_throws(self._env.process_group),
+                )
+            )
+            for table_name, local_shards in self._model_parallel_name_to_local_shards.items()
+            if local_shards
+        }
+
+        gathered_metadata: List[Dict[str, ShardedTensorMetadata]] = [
+            {} for _ in range(world_size)
+        ]
+
+        dist.all_gather_object(
+            gathered_metadata,
+            rank_metadata,
+            group=none_throws(self._env.process_group),
+        )
+
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            if not hasattr(self.embedding_bags[table_name], "weight"):
+                self.embedding_bags[table_name].register_parameter(
+                    "weight", nn.Parameter(torch.empty(0))
+                )
+            if name_to_kernel[table_name] != EmbeddingComputeKernel.DENSE.value:
+                self.embedding_bags[table_name].weight._in_backward_optimizers = [
+                    EmptyFusedOptimizer()
+                ]
+
+            self._model_parallel_name_to_sharded_tensor[table_name] = (
+                ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=self._model_parallel_name_to_local_shards[table_name],
+                    sharded_tensor_metadata=build_global_metadata(
+                        [item.get(table_name, None) for item in gathered_metadata]
+                    ),
+                    process_group=none_throws(self._env.process_group),
+                )
+            )
+
+    def _construct_dtensor_map(
+        self,
+        name_to_kernel: Dict[str, str],
+    ) -> None:
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
+            # for shards that don't exist on this rank, register with empty tensor
+            if not hasattr(self.embedding_bags[table_name], "weight"):
+                self.embedding_bags[table_name].register_parameter(
+                    "weight", nn.Parameter(torch.empty(0))
+                )
+            if name_to_kernel[table_name] != EmbeddingComputeKernel.DENSE.value:
+                self.embedding_bags[table_name].weight._in_backward_optimizers = [
+                    EmptyFusedOptimizer()
+                ]
+
+            if shards_wrapper_map["local_tensors"]:
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=shards_wrapper_map["local_tensors"],
+                        local_offsets=shards_wrapper_map["local_offsets"],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    placements=shards_wrapper_map["placements"],
+                    shape=shards_wrapper_map["global_size"],
+                    stride=shards_wrapper_map["global_stride"],
+                    run_check=False,
+                )
+            else:
+                # empty shard case
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=[],
+                        local_offsets=[],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    run_check=False,
+                )
+
    def _initialize_torch_state(self) -> None:  # noqa
        """
        This provides consistency between this class and the EmbeddingBagCollection's
        nn.Module API calls (state_dict, named_modules, etc)
        """

        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
        for table_name in self._table_names:
            self.embedding_bags[table_name] = nn.Module()
@@ -888,63 +979,15 @@ def _initialize_torch_state(self) -> None:  # noqa
            ) in lookup.named_parameters_by_table():
                self.embedding_bags[table_name].register_parameter("weight", tbe_slice)

-        for table_name in self._model_parallel_name_to_local_shards.keys():
-            local_shards = self._model_parallel_name_to_local_shards[table_name]
-            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
-            # for shards that don't exist on this rank, register with empty tensor
-            if not hasattr(self.embedding_bags[table_name], "weight"):
-                self.embedding_bags[table_name].register_parameter(
-                    "weight", nn.Parameter(torch.empty(0))
-                )
-            if (
-                self._model_parallel_name_to_compute_kernel[table_name]
-                != EmbeddingComputeKernel.DENSE.value
-            ):
-                self.embedding_bags[table_name].weight._in_backward_optimizers = [
-                    EmptyFusedOptimizer()
-                ]
+        if self._output_dtensor:
+            self._construct_dtensor_map(
+                name_to_kernel=self._model_parallel_name_to_compute_kernel
+            )
+        else:
-
-            if self._output_dtensor:
-                assert self._model_parallel_name_to_compute_kernel[table_name] not in {
-                    EmbeddingComputeKernel.KEY_VALUE.value
-                }
-                if shards_wrapper_map["local_tensors"]:
-                    self._model_parallel_name_to_dtensor[table_name] = (
-                        DTensor.from_local(
-                            local_tensor=LocalShardsWrapper(
-                                local_shards=shards_wrapper_map["local_tensors"],
-                                local_offsets=shards_wrapper_map["local_offsets"],
-                            ),
-                            device_mesh=self._env.device_mesh,
-                            placements=shards_wrapper_map["placements"],
-                            shape=shards_wrapper_map["global_size"],
-                            stride=shards_wrapper_map["global_stride"],
-                            run_check=False,
-                        )
-                    )
-                else:
-                    # empty shard case
-                    self._model_parallel_name_to_dtensor[table_name] = (
-                        DTensor.from_local(
-                            local_tensor=LocalShardsWrapper(
-                                local_shards=[],
-                                local_offsets=[],
-                            ),
-                            device_mesh=self._env.device_mesh,
-                            run_check=False,
-                        )
-                    )
-            else:
-                # created ShardedTensors once in init, use in post_state_dict_hook
-                # note: at this point kvstore backed tensors don't own valid snapshots, so no read
-                # access is allowed on them.
-                self._model_parallel_name_to_sharded_tensor[table_name] = (
-                    ShardedTensor._init_from_local_shards(
-                        local_shards,
-                        self._name_to_table_size[table_name],
-                        process_group=self._env.process_group,
-                    )
-                )
+            self._construct_sharded_tensors_map(
+                name_to_kernel=self._model_parallel_name_to_compute_kernel
+            )

        def extract_sharded_kvtensors(
            module: ShardedEmbeddingBagCollection,
