diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 800b37185..e631f77a3 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -9,6 +9,7 @@ import abc import copy +import os from dataclasses import dataclass from enum import Enum, unique from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union @@ -343,6 +344,12 @@ def __init__( self._lookups: List[nn.Module] = [] self._output_dists: List[nn.Module] = [] + # option to construct ShardedTensor from metadata avoiding expensive all-gather + self._construct_sharded_tensor_from_metadata: bool = ( + os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0") + == "1" + ) + def prefetch( self, dist_input: KJTList, diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index f2079a50c..8cfd16ae9 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -29,6 +29,7 @@ from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings from torch import distributed as dist, nn, Tensor from torch.autograd.profiler import record_function +from torch.distributed._shard.sharded_tensor import TensorProperties from torch.distributed._tensor import DTensor from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel @@ -81,6 +82,7 @@ optimizer_type_to_emb_opt_type, ) from torchrec.modules.embedding_configs import ( + data_type_to_dtype, EmbeddingBagConfig, EmbeddingTableConfig, PoolingType, @@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None: # noqa # created ShardedTensors once in init, use in post_state_dict_hook # note: at this point kvstore backed tensors don't own valid snapshots, so no read # access is allowed on them. - self._model_parallel_name_to_sharded_tensor[table_name] = ( - ShardedTensor._init_from_local_shards( - local_shards, - self._name_to_table_size[table_name], - process_group=( - self._env.sharding_pg - if isinstance(self._env, ShardingEnv2D) - else self._env.process_group + + # create ShardedTensor from local shards and metadata avoding all_gather collective + if self._construct_sharded_tensor_from_metadata: + sharding_spec = none_throws( + self.module_sharding_plan[table_name].sharding_spec + ) + + tensor_properties = TensorProperties( + dtype=( + data_type_to_dtype( + self._table_name_to_config[table_name].data_type + ) ), ) - ) + + self._model_parallel_name_to_sharded_tensor[table_name] = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=sharding_spec.build_metadata( + tensor_sizes=self._name_to_table_size[table_name], + tensor_properties=tensor_properties, + ), + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), + ) + ) + else: + # create ShardedTensor from local shards using all_gather collective + self._model_parallel_name_to_sharded_tensor[table_name] = ( + ShardedTensor._init_from_local_shards( + local_shards, + self._name_to_table_size[table_name], + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), + ) + ) def extract_sharded_kvtensors( module: ShardedEmbeddingBagCollection, diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 1ba371e21..879e3a3c7 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -28,6 +28,7 @@ from torchrec.distributed.types import ModuleSharder, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType from torchrec.test_utils import seed_and_log, skip_if_asan_class +from torchrec.types import DataType class ModelParallelTestShared(MultiProcessTestBase): @@ -35,27 +36,48 @@ class ModelParallelTestShared(MultiProcessTestBase): def setUp(self, backend: str = "nccl") -> None: super().setUp() - num_features = 4 - num_weighted_features = 2 - shared_features = 2 + self.num_features = 4 + self.num_weighted_features = 2 + self.num_shared_features = 2 + self.tables = [] + self.mean_tables = [] + self.weighted_tables = [] + self.embedding_groups = {} + self.shared_features = [] + + self.backend = backend + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + if self.backend == "nccl" and self.device == torch.device("cpu"): + self.skipTest("NCCL not supported on CPUs.") + + def _build_tables_and_groups( + self, + data_type: DataType = DataType.FP32, + ) -> None: self.tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, embedding_dim=(i + 2) * 8, name="table_" + str(i), feature_names=["feature_" + str(i)], + data_type=data_type, ) - for i in range(num_features) + for i in range(self.num_features) ] shared_features_tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, embedding_dim=(i + 2) * 8, - name="table_" + str(i + num_features), + name="table_" + str(i + self.num_features), feature_names=["feature_" + str(i)], + data_type=data_type, ) - for i in range(shared_features) + for i in range(self.num_shared_features) ] self.tables += shared_features_tables @@ -66,19 +88,21 @@ def setUp(self, backend: str = "nccl") -> None: name="table_" + str(i), feature_names=["feature_" + str(i)], pooling=PoolingType.MEAN, + data_type=data_type, ) - for i in range(num_features) + for i in range(self.num_features) ] shared_features_tables_mean = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, embedding_dim=(i + 2) * 8, - name="table_" + str(i + num_features), + name="table_" + str(i + self.num_features), feature_names=["feature_" + str(i)], pooling=PoolingType.MEAN, + data_type=data_type, ) - for i in range(shared_features) + for i in range(self.num_shared_features) ] self.mean_tables += shared_features_tables_mean @@ -88,11 +112,11 @@ def setUp(self, backend: str = "nccl") -> None: embedding_dim=(i + 2) * 4, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], + data_type=data_type, ) - for i in range(num_weighted_features) + for i in range(self.num_weighted_features) ] - - self.shared_features = [f"feature_{i}" for i in range(shared_features)] + self.shared_features = [f"feature_{i}" for i in range(self.num_shared_features)] self.embedding_groups = { "group_0": [ ( @@ -104,14 +128,6 @@ def setUp(self, backend: str = "nccl") -> None: for feature in table.feature_names ] } - self.backend = backend - if torch.cuda.is_available(): - self.device = torch.device("cuda") - else: - self.device = torch.device("cpu") - - if self.backend == "nccl" and self.device == torch.device("cpu"): - self.skipTest("NCCL not supported on CPUs.") def _test_sharding( self, @@ -132,7 +148,9 @@ def _test_sharding( has_weighted_tables: bool = True, global_constant_batch: bool = False, pooling: PoolingType = PoolingType.SUM, + data_type: DataType = DataType.FP32, ) -> None: + self._build_tables_and_groups(data_type=data_type) self._run_multi_process_test( callable=sharding_single_rank_test, world_size=world_size, @@ -198,6 +216,7 @@ def setUp(self, backend: str = "nccl") -> None: ), variable_batch_size=st.booleans(), pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) def test_sharding_rw( @@ -210,6 +229,7 @@ def test_sharding_rw( ], variable_batch_size: bool, pooling: PoolingType, + data_type: DataType, ) -> None: if self.backend == "gloo": self.skipTest( @@ -240,6 +260,7 @@ def test_sharding_rw( apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, pooling=pooling, + data_type=data_type, ) # pyre-fixme[56] @@ -252,6 +273,7 @@ def test_sharding_rw( ), kernel_type=st.just(EmbeddingComputeKernel.DENSE.value), apply_optimizer_in_backward_config=st.just(None), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), # TODO - need to enable optimizer overlapped behavior for data_parallel tables ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) @@ -262,6 +284,7 @@ def test_sharding_dp( apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], + data_type: DataType, ) -> None: sharding_type = ShardingType.DATA_PARALLEL.value self._test_sharding( @@ -271,6 +294,7 @@ def test_sharding_dp( ], backend=self.backend, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + data_type=data_type, ) # pyre-fixme[56] @@ -306,6 +330,7 @@ def test_sharding_dp( ] ), variable_batch_size=st.booleans(), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_cw( @@ -317,6 +342,7 @@ def test_sharding_cw( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, + data_type: DataType, ) -> None: if ( self.device == torch.device("cpu") @@ -348,6 +374,7 @@ def test_sharding_cw( }, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + data_type=data_type, ) # pyre-fixme[56] @@ -383,6 +410,7 @@ def test_sharding_cw( ] ), variable_batch_size=st.booleans(), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_twcw( @@ -394,6 +422,7 @@ def test_sharding_twcw( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, + data_type: DataType, ) -> None: if ( self.device == torch.device("cpu") @@ -425,6 +454,7 @@ def test_sharding_twcw( }, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + data_type=data_type, ) # pyre-fixme[56] @@ -461,6 +491,7 @@ def test_sharding_twcw( ] ), variable_batch_size=st.booleans(), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_tw( @@ -472,6 +503,7 @@ def test_sharding_tw( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, + data_type: DataType, ) -> None: if ( self.device == torch.device("cpu") @@ -499,6 +531,7 @@ def test_sharding_tw( qcomms_config=qcomms_config, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + data_type=data_type, ) @unittest.skipIf( @@ -540,6 +573,7 @@ def test_sharding_tw( ), variable_batch_size=st.booleans(), pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) def test_sharding_twrw( @@ -552,6 +586,7 @@ def test_sharding_twrw( ], variable_batch_size: bool, pooling: PoolingType, + data_type: DataType, ) -> None: if self.backend == "gloo": self.skipTest( @@ -597,6 +632,7 @@ def test_sharding_twrw( ), global_constant_batch=st.booleans(), pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) def test_sharding_variable_batch( @@ -604,6 +640,7 @@ def test_sharding_variable_batch( sharding_type: str, global_constant_batch: bool, pooling: PoolingType, + data_type: DataType, ) -> None: if self.backend == "gloo": # error is from FBGEMM, it says CPU even if we are on GPU. @@ -634,6 +671,7 @@ def test_sharding_variable_batch( has_weighted_tables=False, global_constant_batch=global_constant_batch, pooling=pooling, + data_type=data_type, ) @unittest.skipIf( @@ -641,9 +679,14 @@ def test_sharding_variable_batch( "Not enough GPUs, this test requires at least two GPUs", ) # pyre-fixme[56] - @given(sharding_type=st.just(ShardingType.COLUMN_WISE.value)) + @given( + sharding_type=st.just(ShardingType.COLUMN_WISE.value), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) - def test_sharding_multiple_kernels(self, sharding_type: str) -> None: + def test_sharding_multiple_kernels( + self, sharding_type: str, data_type: DataType + ) -> None: if self.backend == "gloo": self.skipTest("ProcessGroupGloo does not support reduce_scatter") constraints = { @@ -665,6 +708,7 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None: constraints=constraints, variable_batch_per_feature=True, has_weighted_tables=False, + data_type=data_type, ) @unittest.skipIf( diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 4b0aedfd6..f2b65a833 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -223,7 +223,12 @@ def copy_state_dict( if isinstance(tensor, ShardedTensor): for local_shard in tensor.local_shards(): - assert global_tensor.ndim == local_shard.tensor.ndim + assert ( + global_tensor.ndim == local_shard.tensor.ndim + ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {local_shard.tensor.ndim}" + assert ( + global_tensor.dtype == local_shard.tensor.dtype + ), f"global tensor dtype: {global_tensor.dtype}, local tensor dtype: {local_shard.tensor.dtype}" shard_meta = local_shard.metadata t = global_tensor.detach() if t.ndim == 1: @@ -246,7 +251,13 @@ def copy_state_dict( tensor.to_local().local_shards(), tensor.to_local().local_offsets(), # pyre-ignore[16] ): - assert global_tensor.ndim == local_shard.ndim + assert ( + global_tensor.ndim == local_shard.ndim + ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.ndim: {local_shard.ndim}" + assert ( + global_tensor.dtype == local_shard.dtype + ), f"global_tensor.dtype: {global_tensor.dtype}, local_shard.dtype: {local_shard.tensor.dtype}" + t = global_tensor.detach() local_shape = local_shard.shape if t.ndim == 1: diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index b36800d08..5dc18885a 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -52,7 +52,7 @@ ShardingType, ShardMetadata, ) -from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_configs import data_type_to_dtype, EmbeddingBagConfig from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, EmbeddingCollection, @@ -71,6 +71,7 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.test_utils import skip_if_asan_class +from torchrec.types import DataType def _test_sharding( @@ -145,12 +146,15 @@ class ConstructParameterShardingAndShardTest(MultiProcessTestBase): }, ] ), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_parameter_sharding_ebc( self, per_param_sharding: Dict[str, ParameterShardingGenerator], + data_type: DataType, ) -> None: + WORLD_SIZE = 2 embedding_bag_config = [ @@ -159,12 +163,14 @@ def test_parameter_sharding_ebc( feature_names=["feature_0"], embedding_dim=16, num_embeddings=4, + data_type=data_type, ), EmbeddingBagConfig( name="table_1", feature_names=["feature_1"], embedding_dim=16, num_embeddings=4, + data_type=data_type, ), ] @@ -213,21 +219,23 @@ def test_parameter_sharding_ebc( world_size=WORLD_SIZE, tables=embedding_bag_config, initial_state_dict={ - "embedding_bags.table_0.weight": torch.Tensor( + "embedding_bags.table_0.weight": torch.tensor( [ [1] * 16, [2] * 16, [3] * 16, [4] * 16, - ] + ], + dtype=data_type_to_dtype(data_type), ), - "embedding_bags.table_1.weight": torch.Tensor( + "embedding_bags.table_1.weight": torch.tensor( [ [101] * 16, [102] * 16, [103] * 16, [104] * 16, - ] + ], + dtype=data_type_to_dtype(data_type), ), }, kjt_input_per_rank=kjt_input_per_rank, @@ -237,13 +245,18 @@ def test_parameter_sharding_ebc( class ConstructParameterShardingTest(unittest.TestCase): - def test_construct_module_sharding_plan(self) -> None: + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_construct_module_sharding_plan(self, data_type: DataType) -> None: + embedding_bag_config = [ EmbeddingBagConfig( name=f"table_{idx}", feature_names=[f"feature_{idx}"], embedding_dim=256, num_embeddings=32 * 32, + data_type=data_type, ) for idx in range(6) ] @@ -679,13 +692,18 @@ def test_construct_module_sharding_plan(self) -> None: ) self.assertDictEqual(expected, module_sharding_plan) - def test_table_wise_set_device(self) -> None: + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_table_wise_set_device(self, data_type: DataType) -> None: + embedding_bag_config = [ EmbeddingBagConfig( name=f"table_{idx}", feature_names=[f"feature_{idx}"], embedding_dim=64, num_embeddings=4096, + data_type=data_type, ) for idx in range(2) ] @@ -718,13 +736,18 @@ def test_table_wise_set_device(self) -> None: "cpu", ) - def test_row_wise_set_heterogenous_device(self) -> None: + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None: + embedding_bag_config = [ EmbeddingBagConfig( name=f"table_{idx}", feature_names=[f"feature_{idx}"], embedding_dim=64, num_embeddings=4096, + data_type=data_type, ) for idx in range(2) ] @@ -732,7 +755,10 @@ def test_row_wise_set_heterogenous_device(self) -> None: EmbeddingBagCollection(tables=embedding_bag_config), per_param_sharding={ "table_0": row_wise( - sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"]) + sizes_placement=( + [2048, 1024, 1024], + ["cpu", "cuda", "cuda"], + ) ), "table_1": row_wise( sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"]) @@ -790,13 +816,18 @@ def test_row_wise_set_heterogenous_device(self) -> None: 0, ) - def test_column_wise(self) -> None: + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_column_wise(self, data_type: DataType) -> None: + embedding_bag_config = [ EmbeddingBagConfig( name=f"table_{idx}", feature_names=[f"feature_{idx}"], embedding_dim=64, num_embeddings=4096, + data_type=data_type, ) for idx in range(2) ]