diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index b575a3457..1c6377ddc 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -130,6 +130,9 @@ def __init__(
         def _create_lookup(
             config: GroupedEmbeddingConfig,
         ) -> BaseEmbedding:
+            for table in config.embedding_tables:
+                if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
+                    self._need_prefetch = True
             if config.compute_kernel == EmbeddingComputeKernel.DENSE:
                 return BatchedDenseEmbedding(
                     config=config,
@@ -149,6 +152,7 @@ def _create_lookup(
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
+        self._need_prefetch: bool = False
         for config in grouped_configs:
             self._emb_modules.append(_create_lookup(config))
@@ -169,6 +173,34 @@ def _create_lookup(
         self.grouped_configs = grouped_configs
 
+    def prefetch(
+        self,
+        sparse_features: KeyedJaggedTensor,
+        forward_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        if not self._need_prefetch:
+            return
+        if len(self._emb_modules) > 0:
+            assert sparse_features is not None
+            features_by_group = sparse_features.split(
+                self._feature_splits,
+            )
+            for emb_op, features in zip(self._emb_modules, features_by_group):
+                if (
+                    isinstance(emb_op.emb_module, SplitTableBatchedEmbeddingBagsCodegen)
+                    and not emb_op.emb_module.prefetch_pipeline
+                ):
+                    logging.error(
+                        "Invalid setting on SplitTableBatchedEmbeddingBagsCodegen modules. prefetch_pipeline must be set to True.\n"
+                        "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
+                    )
+                if hasattr(emb_op.emb_module, "prefetch"):
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                        forward_stream=forward_stream,
+                    )
+
     def forward(
         self,
         sparse_features: KeyedJaggedTensor,
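
For context, here is a minimal sketch of how the new `prefetch` method could be driven from a training loop; it is not part of this change. The names `lookup`, `next_batch_features`, and the stream variables are illustrative assumptions, and the sketch assumes a lookup module built from grouped configs where at least one table uses the `FUSED_UVM_CACHING` compute kernel (otherwise `prefetch` returns immediately).

```python
import torch

# Hypothetical setup: `lookup` is a GroupedEmbeddingsLookup-style module and
# `next_batch_features` is the KeyedJaggedTensor for the upcoming iteration.
prefetch_stream = torch.cuda.Stream()
forward_stream = torch.cuda.current_stream()

with torch.cuda.stream(prefetch_stream):
    # Issue the UVM cache prefetch on a side stream; passing forward_stream lets
    # the embedding kernel synchronize fetched cache rows with the consumer stream.
    # Note: per the diff, SplitTableBatchedEmbeddingBagsCodegen modules should be
    # created with prefetch_pipeline=True, or an error is logged.
    lookup.prefetch(
        sparse_features=next_batch_features,
        forward_stream=forward_stream,
    )

# Later, the regular forward pass consumes the same features on the forward stream.
out = lookup(next_batch_features)
```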