diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 03107856f..f15455089 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -559,14 +559,25 @@ def variable_batch_all2all_pooled_sync( ] with record_function("## alltoall_fwd_single ##"): - sharded_output_embeddings = torch.ops.torchrec.all_to_all_single( - sharded_input_embeddings, - output_split_sizes, - input_split_sizes, - pg_name(pg), - pg.size(), - get_gradient_division(), - ) + if pg._get_backend_name() == "fake": + sharded_output_embeddings = torch.empty( + sum(output_split_sizes), + device=sharded_input_embeddings.device, + dtype=sharded_input_embeddings.dtype, + ) + s0 = sharded_output_embeddings.size(0) + # Bad assumption that our rank GE than other + torch._check(s0 <= sharded_input_embeddings.size(0)) + sharded_output_embeddings.copy_(sharded_input_embeddings[:s0]) + else: + sharded_output_embeddings = torch.ops.torchrec.all_to_all_single( + sharded_input_embeddings, + output_split_sizes, + input_split_sizes, + pg_name(pg), + pg.size(), + get_gradient_division(), + ) if a2ai.codecs is not None: codecs = none_throws(a2ai.codecs) diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index 944819c61..53ce4a3b9 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -239,12 +239,25 @@ def __init__( # https://github.com/pytorch/pytorch/issues/122788 with record_function("## all2all_data:kjt splits ##"): input_tensor = torch.stack(input_tensors, dim=1).flatten() - self._output_tensor = dist._functional_collectives.all_to_all_single( - input_tensor, - output_split_sizes=None, - input_split_sizes=None, - group=pg, - ) + if pg._get_backend_name() == "fake": + self._output_tensor = torch.empty( + [self.num_workers * len(input_tensors)], + device=input_tensors[0].device, + dtype=input_tensors[0].dtype, + ) + + self._output_tensor = input_tensor[ + : input_tensor.size(0) // 2 + ].repeat(2) + else: + self._output_tensor = ( + dist._functional_collectives.all_to_all_single( + input_tensor, + output_split_sizes=None, + input_split_sizes=None, + group=pg, + ) + ) # To avoid hasattr in _wait_impl to check self._splits_awaitable # pyre-ignore self._splits_awaitable = None @@ -342,6 +355,7 @@ def __init__( self._output_tensors: List[torch.Tensor] = [] self._awaitables: List[dist.Work] = [] self._world_size: int = self._pg.size() + rank = dist.get_rank(self._pg) for input_split, output_split, input_tensor, label in zip( input_splits, @@ -353,12 +367,28 @@ def __init__( # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_ # https://github.com/pytorch/pytorch/issues/122788 with record_function(f"## all2all_data:kjt {label} ##"): - output_tensor = dist._functional_collectives.all_to_all_single( - input_tensor, - output_split, - input_split, - pg, - ) + if self._pg._get_backend_name() == "fake": + output_tensor = torch.empty( + sum(output_split), + device=self._device, + dtype=input_tensor.dtype, + ) + _l = sum(output_split[:rank]) + _r = _l + output_split[rank] + torch._check(_r < input_tensor.size(0)) + torch._check(_l < input_tensor.size(0)) + torch._check(_l <= _r) + torch._check(2 * (_r - _l) == output_tensor.size(0)) + output_tensor.copy_( + input_tensor[_l:_r].repeat(self._world_size) + ) + else: + output_tensor = dist._functional_collectives.all_to_all_single( + input_tensor, + output_split, + input_split, + pg, + ) self._output_tensors.append(output_tensor) else: output_tensor = torch.empty( diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index fa096001c..c1e70142f 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -657,7 +657,9 @@ def __init__( broadcast_buffers=True, static_graph=True, ) - self._initialize_torch_state() + + if env.process_group and dist.get_backend(env.process_group) != "fake": + self._initialize_torch_state() # TODO[zainhuda]: support module device coming from CPU if module.device not in ["meta", "cpu"] and module.device.type not in [ diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py index 77d8eff7d..a5f340bd6 100644 --- a/torchrec/distributed/tests/test_pt2_multiprocess.py +++ b/torchrec/distributed/tests/test_pt2_multiprocess.py @@ -25,6 +25,8 @@ from hypothesis import given, settings, strategies as st, Verbosity from torch import distributed as dist from torch._dynamo.testing import reduce_to_scalar_loss +from torch.distributed import ProcessGroup +from torch.testing._internal.distributed.fake_pg import FakeStore from torchrec.distributed.embedding import EmbeddingCollectionSharder from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig @@ -499,6 +501,184 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor: ##### NUMERIC CHECK END ##### +def _test_compile_fake_pg_fn( + rank: int, + world_size: int, +) -> None: + sharding_type = ShardingType.TABLE_WISE.value + input_type = _InputType.SINGLE_BATCH + torch_compile_backend = "eager" + config = _TestConfig() + num_embeddings = 256 + # emb_dim must be % 4 == 0 for fbgemm + emb_dim = 12 + batch_size = 10 + num_features: int = 5 + + num_float_features: int = 8 + num_weighted_features: int = 1 + + device: torch.Device = torch.device("cuda") + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + pg: ProcessGroup = dist.distributed_c10d._get_default_group() + + topology: Topology = Topology(world_size=world_size, compute_device="cuda") + mi = TestModelInfo( + dense_device=device, + sparse_device=device, + num_features=num_features, + num_float_features=num_float_features, + num_weighted_features=num_weighted_features, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(mi.num_features) + ] + + mi.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(mi.num_weighted_features) + ] + + mi.model = _gen_model(_ModelType.EBC, mi) + mi.model.training = True + + model = mi.model + + planner = EmbeddingShardingPlanner( + topology=Topology(world_size, device.type), + constraints=None, + ) + + sharders = [ + EBCSharderFixedShardingType(sharding_type), + ECSharderFixedShardingType(sharding_type), + ] + + plan: ShardingPlan = planner.plan(model, sharders) # pyre-ignore + + def _dmp(m: torch.nn.Module) -> DistributedModelParallel: # pyre-ignore + return DistributedModelParallel( + m, + env=ShardingEnv(world_size, rank, pg), + plan=plan, + sharders=sharders, + device=device, + init_data_parallel=False, + ) + + dmp = _dmp(model) + dmp_compile = _dmp(model) + + # TODO: Fix some data dependent failures on subsequent inputs + n_extra_numerics_checks = config.n_extra_numerics_checks_inputs + ins = [] + + for _ in range(1 + n_extra_numerics_checks): + if input_type == _InputType.VARIABLE_BATCH: + ( + _, + local_model_inputs, + ) = ModelInput.generate_variable_batch_input( + average_batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + # pyre-ignore + tables=mi.tables, + ) + else: + ( + _, + local_model_inputs, + ) = ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=mi.tables, + weighted_tables=mi.weighted_tables, + variable_batch_size=False, + ) + ins.append(local_model_inputs) + + local_model_input = ins[0][rank].to(device) + + kjt = local_model_input.idlist_features + ff = local_model_input.float_features + ff.requires_grad = True + kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=True) + + compile_input_ff = ff.clone().detach() + compile_input_ff.requires_grad = True + + torchrec.distributed.comm_ops.set_use_sync_collectives(True) + torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True) + + dmp.train(True) + dmp_compile.train(True) + + def get_weights(dmp: DistributedModelParallel) -> torch.Tensor: + tbe = dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module + assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen) + return tbe.weights_dev.clone().detach() + + original_weights = get_weights(dmp) + original_weights.zero_() + original_compile_weights = get_weights(dmp_compile) + original_compile_weights.zero_() + + eager_out = dmp(kjt_ft, ff) + reduce_to_scalar_loss(eager_out).backward() + + if torch_compile_backend is None: + return + + ##### COMPILE ##### + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = True + + opt_fn = torch.compile( + dmp_compile, + backend=torch_compile_backend, + fullgraph=True, + ) + compile_out = opt_fn( + kjt_for_pt2_tracing(kjt, convert_to_vb=True), compile_input_ff + ) + torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3) + ##### COMPILE END ##### + + class TestPt2Train(MultiProcessTestBase): def disable_cuda_tf32(self) -> bool: return True @@ -580,3 +760,17 @@ def test_compile_multiprocess( config=config, torch_compile_backend=compile_backend, ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires one GPU", + ) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_compile_multiprocess_fake_pg( + self, + ) -> None: + _test_compile_fake_pg_fn( + rank=0, + world_size=2, + )