From b3649c8b8b6fd5eeee4ae75579f5fcb7b2b955a0 Mon Sep 17 00:00:00 2001
From: Boris Sarana
Date: Fri, 15 Nov 2024 09:22:28 -0800
Subject: [PATCH] Optimize performance of embeddings sharding

Summary:
While working on TTFB it was observed that sharding of embedding bags takes
significant time and is one of the biggest contributors to TTFB, especially
on large jobs. Strobelight data analysis made it clear that most of this time
is spent in all_gather collective calls. Currently we construct sharded
tensors one by one, issuing a collective per tensor to exchange metadata,
which is inefficient. A more optimal approach is to let all ranks build their
portion of the metadata for all tensors and exchange it with a single
collective call, significantly reducing overhead and improving performance.
Testing on 256 ranks showed a ~13x speed up.

Differential Revision: D65489998
---
 torchrec/distributed/embeddingbag.py | 177 +++++++++++++++++----------
 1 file changed, 110 insertions(+), 67 deletions(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index c737df185..e01edb176 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -29,7 +29,14 @@
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
-from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
+from torch.distributed._shard.sharded_tensor.utils import (
+    _flatten_tensor_size,
+    build_global_metadata,
+    build_metadata_from_local_shards,
+)
+from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ -98,13 +105,6 @@
     pass


-# OSS
-try:
-    pass
-except ImportError:
-    pass
-
-
 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     return (
         tensor
@@ -615,10 +615,9 @@ def __init__(
         )
         self._env = env
         # output parameters as DTensor in state dict
-        self._output_dtensor: bool = (
-            fused_params.get("output_dtensor", False) if fused_params else False
+        self._output_dtensor: bool = fused_params and fused_params.get(
+            "output_dtensor", False
         )
-
         sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
             module,
             table_name_to_parameter_sharding,
@@ -818,11 +817,103 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()

+    def _construct_sharded_tensors_map(
+        self,
+        name_to_kernel: Dict[str, str],
+    ) -> None:
+        # Collect metadata for all distributed tensors, do single collective call
+        # and merge them to avoid overhead doing it per tensor
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+
+        rank_metadata = {
+            table_name: (
+                build_metadata_from_local_shards(
+                    local_shards,
+                    _flatten_tensor_size(self._name_to_table_size[table_name]),
+                    rank,
+                    none_throws(self._env.process_group),
+                )
+            )
+            for table_name, local_shards in self._model_parallel_name_to_local_shards.items()
+            if local_shards
+        }
+
+        gathered_metadata: List[Dict[str, ShardedTensorMetadata]] = [
+            {} for _ in range(world_size)
+        ]
+
+        dist.all_gather_object(
+            gathered_metadata,
+            rank_metadata,
+            group=none_throws(self._env.process_group),
+        )
+
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            if not hasattr(self.embedding_bags[table_name], "weight"):
+                self.embedding_bags[table_name].register_parameter(
+                    "weight", nn.Parameter(torch.empty(0))
+                )
+                if name_to_kernel[table_name] != EmbeddingComputeKernel.DENSE.value:
+                    self.embedding_bags[table_name].weight._in_backward_optimizers = [
+                        EmptyFusedOptimizer()
+                    ]
+
+            self._model_parallel_name_to_sharded_tensor[table_name] = (
+                ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=self._model_parallel_name_to_local_shards[table_name],
+                    sharded_tensor_metadata=build_global_metadata(
+                        [item.get(table_name, None) for item in gathered_metadata]
+                    ),
+                    process_group=none_throws(self._env.process_group),
+                )
+            )
+
+    def _construct_dtensor_map(
+        self,
+        name_to_kernel: Dict[str, str],
+    ) -> None:
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
+            # for shards that don't exist on this rank, register with empty tensor
+            if not hasattr(self.embedding_bags[table_name], "weight"):
+                self.embedding_bags[table_name].register_parameter(
+                    "weight", nn.Parameter(torch.empty(0))
+                )
+                if name_to_kernel[table_name] != EmbeddingComputeKernel.DENSE.value:
+                    self.embedding_bags[table_name].weight._in_backward_optimizers = [
+                        EmptyFusedOptimizer()
+                    ]
+
+            if shards_wrapper_map["local_tensors"]:
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=shards_wrapper_map["local_tensors"],
+                        local_offsets=shards_wrapper_map["local_offsets"],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    placements=shards_wrapper_map["placements"],
+                    shape=shards_wrapper_map["global_size"],
+                    stride=shards_wrapper_map["global_stride"],
+                    run_check=False,
+                )
+            else:
+                # empty shard case
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=[],
+                        local_offsets=[],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    run_check=False,
+                )
+
     def _initialize_torch_state(self) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
         """
+
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
         for table_name in self._table_names:
             self.embedding_bags[table_name] = nn.Module()
@@ -888,63 +979,15 @@ def _initialize_torch_state(self) -> None:  # noqa
             ) in lookup.named_parameters_by_table():
                 self.embedding_bags[table_name].register_parameter("weight", tbe_slice)

-        for table_name in self._model_parallel_name_to_local_shards.keys():
-            local_shards = self._model_parallel_name_to_local_shards[table_name]
-            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
-            # for shards that don't exist on this rank, register with empty tensor
-            if not hasattr(self.embedding_bags[table_name], "weight"):
-                self.embedding_bags[table_name].register_parameter(
-                    "weight", nn.Parameter(torch.empty(0))
-                )
-                if (
-                    self._model_parallel_name_to_compute_kernel[table_name]
-                    != EmbeddingComputeKernel.DENSE.value
-                ):
-                    self.embedding_bags[table_name].weight._in_backward_optimizers = [
-                        EmptyFusedOptimizer()
-                    ]
+        if self._output_dtensor:
+            self._construct_dtensor_map(
+                name_to_kernel=self._model_parallel_name_to_compute_kernel
+            )
+        else:
-            if self._output_dtensor:
-                assert self._model_parallel_name_to_compute_kernel[table_name] not in {
-                    EmbeddingComputeKernel.KEY_VALUE.value
-                }
-                if shards_wrapper_map["local_tensors"]:
-                    self._model_parallel_name_to_dtensor[table_name] = (
-                        DTensor.from_local(
-                            local_tensor=LocalShardsWrapper(
-                                local_shards=shards_wrapper_map["local_tensors"],
-                                local_offsets=shards_wrapper_map["local_offsets"],
-                            ),
-                            device_mesh=self._env.device_mesh,
-                            placements=shards_wrapper_map["placements"],
-                            shape=shards_wrapper_map["global_size"],
-                            stride=shards_wrapper_map["global_stride"],
-                            run_check=False,
-                        )
-                    )
-                else:
-                    # empty shard case
-                    self._model_parallel_name_to_dtensor[table_name] = (
-                        DTensor.from_local(
-                            local_tensor=LocalShardsWrapper(
-                                local_shards=[],
-                                local_offsets=[],
-                            ),
-                            device_mesh=self._env.device_mesh,
-                            run_check=False,
-                        )
-                    )
-            else:
-                # created ShardedTensors once in init, use in post_state_dict_hook
-                # note: at this point kvstore backed tensors don't own valid snapshots, so no read
-                # access is allowed on them.
-                self._model_parallel_name_to_sharded_tensor[table_name] = (
-                    ShardedTensor._init_from_local_shards(
-                        local_shards,
-                        self._name_to_table_size[table_name],
-                        process_group=self._env.process_group,
-                    )
-                )
+            self._construct_sharded_tensors_map(
+                name_to_kernel=self._model_parallel_name_to_compute_kernel
+            )


 def extract_sharded_kvtensors(
     module: ShardedEmbeddingBagCollection,
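Editor's note: the sketch below is not part of the patch. It is a minimal,
standalone illustration of the single-collective metadata exchange described in
the summary, assuming a per-rank dict of ShardedTensorMetadata keyed by table
name; the helper name gather_all_table_metadata is hypothetical.

    import torch.distributed as dist

    def gather_all_table_metadata(rank_metadata, process_group):
        # rank_metadata: Dict[str, ShardedTensorMetadata] built locally for the
        # tables whose shards live on this rank (hypothetical input).
        world_size = dist.get_world_size(process_group)
        gathered = [None] * world_size
        # One all_gather_object covering every table, instead of one collective
        # per sharded tensor; each rank receives every rank's metadata map.
        dist.all_gather_object(gathered, rank_metadata, group=process_group)
        return gathered  # List[Dict[str, ShardedTensorMetadata]], indexed by rank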