diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 9c314f6fa..732d04426 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -26,6 +26,7 @@ ) import torch +from tensordict import TensorDict from torch import distributed as dist, nn from torch.autograd.profiler import record_function from torch.distributed._shard.sharding_spec import EnumerableShardingSpec @@ -90,6 +91,7 @@ from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -97,13 +99,6 @@ except OSError: pass -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - logger: logging.Logger = logging.getLogger(__name__) @@ -1205,25 +1200,50 @@ def _compute_sequence_vbe_context( def input_dist( self, ctx: EmbeddingCollectionContext, - features: KeyedJaggedTensor, + features: TypeUnion[KeyedJaggedTensor, TensorDict], ) -> Awaitable[Awaitable[KJTList]]: + # torch.distributed.breakpoint() + feature_keys = list(features.keys()) # pyre-ignore[6] if self._has_uninitialized_input_dist: - self._create_input_dist(input_feature_names=features.keys()) + self._create_input_dist(input_feature_names=feature_keys) self._has_uninitialized_input_dist = False with torch.no_grad(): unpadded_features = None - if features.variable_stride_per_key(): + if ( + isinstance(features, KeyedJaggedTensor) + and features.variable_stride_per_key() + ): unpadded_features = features features = pad_vbe_kjt_lengths(unpadded_features) - if self._features_order: + if isinstance(features, KeyedJaggedTensor) and self._features_order: features = features.permute( self._features_order, # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` # but got `TypeUnion[Module, Tensor]`. self._features_order_tensor, ) - features_by_shards = features.split(self._feature_splits) + + if isinstance(features, KeyedJaggedTensor): + features_by_shards = features.split(self._feature_splits) + else: # TensorDict + feature_names = ( + [feature_keys[i] for i in self._features_order] + if self._features_order # empty features_order means no reordering + else feature_keys + ) + feature_names = [name.split("@")[0] for name in feature_names] + feature_name_by_sharding_types: List[List[str]] = [] + start = 0 + for length in self._feature_splits: + feature_name_by_sharding_types.append( + feature_names[start : start + length] + ) + start += length + features_by_shards = [ + maybe_td_to_kjt(features, names) + for names in feature_name_by_sharding_types + ] if self._use_index_dedup: features_by_shards = self._dedup_indices(ctx, features_by_shards) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 4b0aedfd6..60d060ef9 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -147,6 +147,7 @@ def gen_model_and_input( long_indices: bool = True, global_constant_batch: bool = False, num_inputs: int = 1, + input_type: str = "kjt", # "kjt" or "td" ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]: torch.manual_seed(0) if dedup_feature_names: @@ -177,9 +178,9 @@ def gen_model_and_input( feature_processor_modules=feature_processor_modules, ) inputs = [] - for _ in range(num_inputs): - inputs.append( - ( + if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input: + for _ in range(num_inputs): + inputs.append( cast(VariableBatchModelInputCallable, generate)( average_batch_size=batch_size, world_size=world_size, @@ -188,8 +189,26 @@ def gen_model_and_input( weighted_tables=weighted_tables or [], global_constant_batch=global_constant_batch, ) - if generate == ModelInput.generate_variable_batch_input - else cast(ModelInputCallable, generate)( + ) + elif generate == ModelInput.generate: + for _ in range(num_inputs): + inputs.append( + ModelInput.generate( + world_size=world_size, + tables=tables, + dedup_tables=dedup_tables, + weighted_tables=weighted_tables or [], + num_float_features=num_float_features, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + long_indices=long_indices, + input_type=input_type, + ) + ) + else: + for _ in range(num_inputs): + inputs.append( + cast(ModelInputCallable, generate)( world_size=world_size, tables=tables, dedup_tables=dedup_tables, @@ -200,7 +219,6 @@ def gen_model_and_input( long_indices=long_indices, ) ) - ) return (model, inputs) @@ -286,6 +304,7 @@ def sharding_single_rank_test( global_constant_batch: bool = False, world_size_2D: Optional[int] = None, node_group_size: Optional[int] = None, + input_type: str = "kjt", # "kjt" or "td" ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: # Generate model & inputs. @@ -308,6 +327,7 @@ def sharding_single_rank_test( batch_size=batch_size, feature_processor_modules=feature_processor_modules, global_constant_batch=global_constant_batch, + input_type=input_type, ) global_model = global_model.to(ctx.device) global_input = inputs[0][0].to(ctx.device) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index aec092354..d13d819c3 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -376,3 +376,44 @@ def _test_sharding( variable_batch_per_feature=variable_batch_per_feature, global_constant_batch=True, ) + + +@skip_if_asan_class +class TDSequenceModelParallelTest(SequenceModelParallelTest): + + def test_sharding_variable_batch(self) -> None: + pass + + def _test_sharding( + self, + sharders: List[TestEmbeddingCollectionSharder], + backend: str = "gloo", + world_size: int = 2, + local_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, + qcomms_config: Optional[QCommsConfig] = None, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ] = None, + variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, + ) -> None: + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=world_size, + local_size=local_size, + model_class=model_class, + tables=self.tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + optim=EmbOptimType.EXACT_SGD, + backend=backend, + constraints=constraints, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=True, + input_type="td", + ) diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 4ade3df2f..d110fd57f 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -219,7 +219,10 @@ def __init__( self._feature_names: List[List[str]] = [table.feature_names for table in tables] self.reset_parameters() - def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + def forward( + self, + features: KeyedJaggedTensor, # can also take TensorDict as input + ) -> KeyedTensor: """ Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature. @@ -450,7 +453,7 @@ def __init__( # noqa C901 def forward( self, - features: KeyedJaggedTensor, + features: KeyedJaggedTensor, # can also take TensorDict as input ) -> Dict[str, JaggedTensor]: """ Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` @@ -463,6 +466,7 @@ def forward( Dict[str, JaggedTensor] """ + features = maybe_td_to_kjt(features, None) feature_embeddings: Dict[str, JaggedTensor] = {} jt_dict: Dict[str, JaggedTensor] = features.to_dict() for i, emb_module in enumerate(self.embeddings.values()):