From 5f8a495e306b0fc46e9b27bc16a3fab90f90c0cd Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Thu, 8 Aug 2024 14:08:37 -0700
Subject: [PATCH] Support prefetching for SSD TBE lookup (#2275)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2275

Currently, we cannot use the prefetch pipeline with SSD-based TBE. This diff
adds the required changes in torchrec to support it.

Reviewed By: chrisxcai

Differential Revision: D60838580

fbshipit-source-id: 71c837554e21651656a77e8e01b36c95d23d135f
---
 torchrec/distributed/embedding_lookup.py | 55 ++++++++++++++++--------
 1 file changed, 38 insertions(+), 17 deletions(-)

diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index f431f8bdf..d5812e7af 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -20,6 +20,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     SplitTableBatchedEmbeddingBagsCodegen,
 )
+from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags
 from torch import nn
 from torch.autograd.function import FunctionCtx
@@ -182,7 +183,10 @@ def _create_lookup(
             config: GroupedEmbeddingConfig,
         ) -> BaseEmbedding:
             for table in config.embedding_tables:
-                if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
+                if (
+                    table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+                    or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
+                ):
                     self._need_prefetch = True
             if config.compute_kernel == EmbeddingComputeKernel.DENSE:
                 return BatchedDenseEmbedding(
@@ -254,11 +258,18 @@ def prefetch(
                 "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
             )
             if hasattr(emb_op.emb_module, "prefetch"):
-                emb_op.emb_module.prefetch(
-                    indices=features.values(),
-                    offsets=features.offsets(),
-                    forward_stream=forward_stream,
-                )
+                if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
+                    # SSD TBE prefetch only takes indices and offsets
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                    )
+                else:
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                        forward_stream=forward_stream,
+                    )

     def forward(
         self,
@@ -455,7 +466,10 @@ def prefetch(
     ) -> None:
         def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
             for table in config.embedding_tables:
-                if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
+                if (
+                    table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+                    or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
+                ):
                     return True
             return False
@@ -476,16 +490,23 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
                 "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
             )
             if hasattr(emb_op.emb_module, "prefetch"):
-                emb_op.emb_module.prefetch(
-                    indices=features.values(),
-                    offsets=features.offsets(),
-                    forward_stream=forward_stream,
-                    batch_size_per_feature_per_rank=(
-                        features.stride_per_key_per_rank()
-                        if features.variable_stride_per_key()
-                        else None
-                    ),
-                )
+                if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
+                    # SSD TBE prefetch only takes indices and offsets
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                    )
+                else:
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                        forward_stream=forward_stream,
+                        batch_size_per_feature_per_rank=(
+                            features.stride_per_key_per_rank()
+                            if features.variable_stride_per_key()
+                            else None
+                        ),
+                    )

     def _merge_variable_batch_embeddings(
         self, embeddings: List[torch.Tensor], splits: List[List[int]]
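
Note (not part of the patch): for readers skimming the diff, the change
reduces to one dispatch on the TBE backend. SSD-backed modules
(SSDTableBatchedEmbeddingBags) expose a prefetch() that accepts only
indices and offsets, while the UVM-caching kernels also accept
forward_stream (and, in the pooled lookup, batch_size_per_feature_per_rank
for variable-batch inputs). The standalone sketch below merges the two
call sites into a single illustrative helper; the name prefetch_tbe and
its flattened argument list are hypothetical, not part of torchrec.

    from typing import List, Optional

    import torch
    from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags


    def prefetch_tbe(
        emb_module: torch.nn.Module,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        forward_stream: Optional[torch.cuda.Stream] = None,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> None:
        # Hypothetical helper mirroring the dispatch this patch adds to
        # GroupedEmbeddingsLookup.prefetch and
        # GroupedPooledEmbeddingsLookup.prefetch.
        if not hasattr(emb_module, "prefetch"):
            return
        if isinstance(emb_module, SSDTableBatchedEmbeddingBags):
            # The SSD TBE prefetch signature is narrower:
            # indices and offsets only.
            emb_module.prefetch(indices=indices, offsets=offsets)
        else:
            # UVM-caching kernels can additionally synchronize against the
            # forward stream and handle per-rank variable batch sizes.
            emb_module.prefetch(
                indices=indices,
                offsets=offsets,
                forward_stream=forward_stream,
                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
            )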