diff --git a/torchrec/distributed/embedding_dim_bucketer.py b/torchrec/distributed/embedding_dim_bucketer.py index 0033b2e72..4709c72c4 100644 --- a/torchrec/distributed/embedding_dim_bucketer.py +++ b/torchrec/distributed/embedding_dim_bucketer.py @@ -9,7 +9,10 @@ from enum import Enum, unique from typing import Dict, List -from torchrec.distributed.embedding_types import ShardedEmbeddingTable +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + ShardedEmbeddingTable, +) from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType @@ -153,3 +156,36 @@ def bucket(self, dim: int, dtype: DataType) -> int: def dim_in_bytes(self, dim: int, dtype: DataType) -> int: return dim * DATA_TYPE_NUM_BITS[dtype] // 8 + + +def should_do_dim_bucketing( + embedding_tables: List[ShardedEmbeddingTable], +) -> bool: + """ + When embedding memory offloading with caching is enabled, we prefer to + do dim bucketing for better utilization of cache space. Only applied to + "prefetch-sparse-dist" training pipeline. + + Currently using the compute kernel to deduct caching is enabled. + """ + table_pipeline_count = 0 + for table in embedding_tables: + if ( + table.fused_params is not None + and "prefetch_pipeline" in table.fused_params + and table.fused_params["prefetch_pipeline"] + ): + table_pipeline_count += 1 + + if table_pipeline_count > 0 and table_pipeline_count != len(embedding_tables): + AssertionError( + f"Only {table_pipeline_count} tables have prefetch-sparse-dist pipeline. It should be all {len(embedding_tables)} tables." + ) + + for table in embedding_tables: + if ( + table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING + and table_pipeline_count + ): + return True + return False diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 0e5599c6c..857e4cb68 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import abc +import logging from dataclasses import dataclass, field from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union @@ -15,6 +16,11 @@ KJTAllToAllTensorsAwaitable, SplitsAllToAllAwaitable, ) +from torchrec.distributed.embedding_dim_bucketer import ( + EmbDimBucketer, + EmbDimBucketerPolicy, + should_do_dim_bucketing, +) from torchrec.distributed.embedding_types import ( BaseEmbeddingLookup, BaseGroupedFeatureProcessor, @@ -151,59 +157,82 @@ def _group_tables_per_rank( EmbeddingComputeKernel.QUANT, ] + emb_dim_bucketer_policy = ( + EmbDimBucketerPolicy.ALL_BUCKETS + if should_do_dim_bucketing(embedding_tables) + else EmbDimBucketerPolicy.SINGLE_BUCKET + ) + emb_dim_bucketer = EmbDimBucketer(embedding_tables, emb_dim_bucketer_policy) + logging.info(f"bucket count {emb_dim_bucketer.bucket_count()}") + for data_type in DataType: for pooling in PoolingType: # remove this when finishing migration for has_feature_processor in [False, True]: for fused_params_group in fused_params_groups: for compute_kernel in compute_kernels: - grouped_tables: List[ShardedEmbeddingTable] = [] - is_weighted = False - for table in embedding_tables: - compute_kernel_type = table.compute_kernel - is_weighted = table.is_weighted - if table.compute_kernel in [ - EmbeddingComputeKernel.FUSED_UVM, - EmbeddingComputeKernel.FUSED_UVM_CACHING, - ]: - compute_kernel_type = EmbeddingComputeKernel.FUSED - elif table.compute_kernel in [ - EmbeddingComputeKernel.QUANT_UVM, - EmbeddingComputeKernel.QUANT_UVM_CACHING, - ]: - compute_kernel_type = EmbeddingComputeKernel.QUANT - if ( - table.data_type == data_type - and table.pooling.value == pooling.value - and table.has_feature_processor - == has_feature_processor - and compute_kernel_type == compute_kernel - and table.fused_params == fused_params_group - ): - grouped_tables.append(table) - - if fused_params_group is None: - fused_params_group = {} - - if grouped_tables: - grouped_embedding_configs.append( - GroupedEmbeddingConfig( - data_type=data_type, - pooling=pooling, - is_weighted=is_weighted, - has_feature_processor=has_feature_processor, - compute_kernel=compute_kernel, - embedding_tables=grouped_tables, - fused_params={ - k: v - for k, v in fused_params_group.items() - if k - not in [ - "_batch_key" - ] # drop '_batch_key' not a native fused param - }, + for dim_bucket in range(emb_dim_bucketer.bucket_count()): + grouped_tables: List[ShardedEmbeddingTable] = [] + is_weighted = False + for table in embedding_tables: + compute_kernel_type = table.compute_kernel + is_weighted = table.is_weighted + if table.compute_kernel in [ + EmbeddingComputeKernel.FUSED_UVM, + EmbeddingComputeKernel.FUSED_UVM_CACHING, + ]: + compute_kernel_type = ( + EmbeddingComputeKernel.FUSED + ) + elif table.compute_kernel in [ + EmbeddingComputeKernel.QUANT_UVM, + EmbeddingComputeKernel.QUANT_UVM_CACHING, + ]: + compute_kernel_type = ( + EmbeddingComputeKernel.QUANT + ) + + if ( + table.data_type == data_type + and table.pooling.value == pooling.value + and table.has_feature_processor + == has_feature_processor + and compute_kernel_type == compute_kernel + and table.fused_params == fused_params_group + and ( + emb_dim_bucketer.get_bucket( + table.embedding_dim, table.data_type + ) + == dim_bucket + ) + ): + grouped_tables.append(table) + + if fused_params_group is None: + fused_params_group = {} + + if grouped_tables: + logging.info( + f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}." + ) + grouped_embedding_configs.append( + GroupedEmbeddingConfig( + data_type=data_type, + pooling=pooling, + is_weighted=is_weighted, + has_feature_processor=has_feature_processor, + compute_kernel=compute_kernel, + embedding_tables=grouped_tables, + fused_params={ + k: v + for k, v in fused_params_group.items() + if k + not in [ + "_batch_key" + ] # drop '_batch_key' not a native fused param + }, + ) ) - ) return grouped_embedding_configs table_weightedness = [ diff --git a/torchrec/distributed/tests/test_emb_dim_bucketer.py b/torchrec/distributed/tests/test_emb_dim_bucketer.py index 34f3d44fb..f3c0f0a9b 100644 --- a/torchrec/distributed/tests/test_emb_dim_bucketer.py +++ b/torchrec/distributed/tests/test_emb_dim_bucketer.py @@ -13,9 +13,13 @@ from torchrec.distributed.embedding_dim_bucketer import ( EmbDimBucketer, EmbDimBucketerPolicy, + should_do_dim_bucketing, ) -from torchrec.distributed.embedding_types import ShardedEmbeddingTable +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + ShardedEmbeddingTable, +) from torchrec.modules.embedding_configs import DataType @@ -36,6 +40,7 @@ def gen_tables(self) -> Tuple[List[ShardedEmbeddingTable], int]: embedding_dim=buckets[i % num_buckets], num_embeddings=random.randint(100, 500000), data_type=DataType.FP16, + compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING, ) ) return embeddings, len(buckets) @@ -86,3 +91,7 @@ def test_all_bucket_policy(self) -> None: for i in range(emb_dim_bucketer.bucket_count()): self.assertTrue(i in emb_dim_bucketer.emb_dim_buckets.values()) + + def test_should_do_dim_bucketing(self) -> None: + embedding_tables, _ = self.gen_tables() + self.assertFalse(should_do_dim_bucketing(embedding_tables))