From 5fd2811657258fd6bb76a60e8238dbfce7c1a448 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Fri, 3 Jan 2025 11:34:41 +0000 Subject: [PATCH] 2025-01-03 nightly release (00d8ed2d6269c243913220119192d2f89dcb93ed) --- torchrec/distributed/embeddingbag.py | 12 + torchrec/distributed/test_utils/test_model.py | 54 ++-- .../tests/test_train_pipelines.py | 246 +++++++++--------- .../tests/test_train_pipelines_base.py | 4 +- .../tests/test_train_pipelines_utils.py | 38 +-- .../train_pipeline/train_pipelines.py | 61 ++--- torchrec/distributed/train_pipeline/utils.py | 238 ++++++++--------- torchrec/distributed/utils.py | 43 ++- 8 files changed, 377 insertions(+), 319 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 5f1ed57f7..f2079a50c 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -32,6 +32,7 @@ from torch.distributed._tensor import DTensor from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.comm import get_local_size from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -73,6 +74,7 @@ add_params_from_parameter_sharding, append_prefix, convert_to_fbgemm_types, + create_global_tensor_shape_stride_from_metadata, maybe_annotate_embedding_event, merge_fused_params, none_throws, @@ -918,6 +920,14 @@ def _initialize_torch_state(self) -> None: # noqa ) ) else: + shape, stride = create_global_tensor_shape_stride_from_metadata( + none_throws(self.module_sharding_plan[table_name]), + ( + self._env.node_group_size + if isinstance(self._env, ShardingEnv2D) + else get_local_size(self._env.world_size) + ), + ) # empty shard case self._model_parallel_name_to_dtensor[table_name] = ( DTensor.from_local( @@ -927,6 +937,8 @@ def _initialize_torch_state(self) -> None: # noqa ), device_mesh=self._env.device_mesh, run_check=False, + shape=shape, + stride=stride, ) ) else: diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 449dd7a79..010abb459 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -1192,7 +1192,7 @@ def __init__( max_feature_lengths: Optional[Dict[str, int]] = None, feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, over_arch_clazz: Type[nn.Module] = TestOverArch, - preproc_module: Optional[nn.Module] = None, + postproc_module: Optional[nn.Module] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -1229,7 +1229,7 @@ def __init__( "dummy_ones", torch.ones(1, device=dense_device), ) - self.preproc_module = preproc_module + self.postproc_module = postproc_module def sparse_forward(self, input: ModelInput) -> KeyedTensor: return self.sparse( @@ -1256,8 +1256,8 @@ def forward( self, input: ModelInput, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if self.preproc_module: - input = self.preproc_module(input) + if self.postproc_module: + input = self.postproc_module(input) return self.dense_forward(input, self.sparse_forward(input)) @@ -1749,18 +1749,18 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]: class TestModelWithPreproc(nn.Module): """ - Basic module with up to 3 preproc modules: - - preproc on idlist_features for non-weighted EBC - - preproc on idscore_features for weighted EBC - - optional preproc on model input shared by both EBCs + Basic module with up to 3 postproc modules: + - postproc on idlist_features for non-weighted EBC + - postproc on idscore_features for weighted EBC + - optional postproc on model input shared by both EBCs Args: tables, weighted_tables, device, - preproc_module, + postproc_module, num_float_features, - run_preproc_inline, + run_postproc_inline, Example: >>> TestModelWithPreproc(tables, weighted_tables, device) @@ -1774,9 +1774,9 @@ def __init__( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], device: torch.device, - preproc_module: Optional[nn.Module] = None, + postproc_module: Optional[nn.Module] = None, num_float_features: int = 10, - run_preproc_inline: bool = False, + run_postproc_inline: bool = False, ) -> None: super().__init__() self.dense = TestDenseArch(num_float_features, device) @@ -1790,17 +1790,17 @@ def __init__( is_weighted=True, device=device, ) - self.preproc_nonweighted = TestPreprocNonWeighted() - self.preproc_weighted = TestPreprocWeighted() - self._preproc_module = preproc_module - self._run_preproc_inline = run_preproc_inline + self.postproc_nonweighted = TestPreprocNonWeighted() + self.postproc_weighted = TestPreprocWeighted() + self._postproc_module = postproc_module + self._run_postproc_inline = run_postproc_inline def forward( self, input: ModelInput, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Runs preprco for EBC and weighted EBC, optionally runs preproc for input + Runs preprco for EBC and weighted EBC, optionally runs postproc for input Args: input @@ -1809,9 +1809,9 @@ def forward( """ modified_input = input - if self._preproc_module is not None: - modified_input = self._preproc_module(modified_input) - elif self._run_preproc_inline: + if self._postproc_module is not None: + modified_input = self._postproc_module(modified_input) + elif self._run_postproc_inline: idlist_features = modified_input.idlist_features modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync( idlist_features.keys(), # pyre-ignore [6] @@ -1819,10 +1819,10 @@ def forward( idlist_features.lengths(), # pyre-ignore [16] ) - modified_idlist_features = self.preproc_nonweighted( + modified_idlist_features = self.postproc_nonweighted( modified_input.idlist_features ) - modified_idscore_features = self.preproc_weighted( + modified_idscore_features = self.postproc_weighted( modified_input.idscore_features ) ebc_out = self.ebc(modified_idlist_features[0]) @@ -1834,15 +1834,15 @@ def forward( class TestNegSamplingModule(torch.nn.Module): """ - Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing + Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing Args: extra_input has_params Example: - >>> preproc = TestNegSamplingModule(extra_input) - >>> out = preproc(in) + >>> postproc = TestNegSamplingModule(extra_input) + >>> out = postproc(in) Returns: ModelInput @@ -1906,8 +1906,8 @@ class TestPositionWeightedPreprocModule(torch.nn.Module): Args: None Example: - >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device) - >>> out = preproc(in) + >>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device) + >>> out = postproc(in) Returns: ModelInput """ diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index b2b309ee4..bf708b1f5 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -61,7 +61,7 @@ DataLoadingThread, get_h2d_func, PipelinedForward, - PipelinedPreproc, + PipelinedPostproc, PipelineStage, SparseDataDistUtil, StageOut, @@ -935,7 +935,7 @@ def _check_output_equal( optimizer=optim_pipelined, device=self.device, execute_all_batches=True, - pipeline_preproc=True, + pipeline_postproc=True, ) for i in range(self.num_batches): @@ -957,14 +957,14 @@ def _check_output_equal( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_modules_share_preproc(self) -> None: + def test_pipeline_modules_share_postproc(self) -> None: """ - Setup: preproc module takes in input batch and returns modified + Setup: postproc module takes in input batch and returns modified input batch. EBC and weighted EBC inside model sparse arch subsequently uses this modified KJT. - Test case where single preproc module is shared by multiple sharded modules - and output of preproc module needs to be transformed in the SAME way + Test case where single postproc module is shared by multiple sharded modules + and output of postproc module needs to be transformed in the SAME way """ extra_input = ModelInput.generate( @@ -976,10 +976,10 @@ def test_pipeline_modules_share_preproc(self) -> None: randomize_indices=False, )[0].to(self.device) - preproc_module = TestNegSamplingModule( + postproc_module = TestNegSamplingModule( extra_input=extra_input, ) - model = self._setup_model(preproc_module=preproc_module) + model = self._setup_model(postproc_module=postproc_module) pipelined_model, pipeline = self._check_output_equal( model, @@ -988,22 +988,22 @@ def test_pipeline_modules_share_preproc(self) -> None: # Check that both EC and EBC pipelined self.assertEqual(len(pipeline._pipelined_modules), 2) - self.assertEqual(len(pipeline._pipelined_preprocs), 1) + self.assertEqual(len(pipeline._pipelined_postprocs), 1) # pyre-ignore @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_preproc_not_shared_with_arg_transform(self) -> None: + def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: """ - Test case where arguments to preproc module is some non-modifying - transformation of the input batch (no nested preproc modules) AND + Test case where arguments to postproc module is some non-modifying + transformation of the input batch (no nested postproc modules) AND arguments to multiple sharded modules can be derived from the output - of different preproc modules (i.e. preproc modules not shared). + of different postproc modules (i.e. postproc modules not shared). """ model = TestModelWithPreproc( - tables=self.tables[:-1], # ignore last table as preproc will remove + tables=self.tables[:-1], # ignore last table as postproc will remove weighted_tables=self.weighted_tables[:-1], # ignore last table device=self.device, ) @@ -1024,53 +1024,53 @@ def test_pipeline_preproc_not_shared_with_arg_transform(self) -> None: self.assertEqual(len(ebc.forward._args), 1) self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) - self.assertEqual(len(ebc.forward._args[0].preproc_modules), 2) + self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) self.assertIsInstance( - ebc.forward._args[0].preproc_modules[0], PipelinedPreproc + ebc.forward._args[0].postproc_modules[0], PipelinedPostproc ) - self.assertEqual(ebc.forward._args[0].preproc_modules[1], None) + self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) self.assertEqual( - pipelined_ebc.forward._args[0].preproc_modules[0], + pipelined_ebc.forward._args[0].postproc_modules[0], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_nonweighted`. - pipelined_model.module.preproc_nonweighted, + # `postproc_nonweighted`. + pipelined_model.module.postproc_nonweighted, ) self.assertEqual( - pipelined_weighted_ebc.forward._args[0].preproc_modules[0], + pipelined_weighted_ebc.forward._args[0].postproc_modules[0], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_weighted`. - pipelined_model.module.preproc_weighted, + # `postproc_weighted`. + pipelined_model.module.postproc_weighted, ) - # preproc args - self.assertEqual(len(pipeline._pipelined_preprocs), 2) + # postproc args + self.assertEqual(len(pipeline._pipelined_postprocs), 2) input_attr_names = {"idlist_features", "idscore_features"} - for i in range(len(pipeline._pipelined_preprocs)): - preproc_mod = pipeline._pipelined_preprocs[i] - self.assertEqual(len(preproc_mod._args), 1) + for i in range(len(pipeline._pipelined_postprocs)): + postproc_mod = pipeline._pipelined_postprocs[i] + self.assertEqual(len(postproc_mod._args), 1) - input_attr_name = preproc_mod._args[0].input_attrs[1] + input_attr_name = postproc_mod._args[0].input_attrs[1] self.assertTrue(input_attr_name in input_attr_names) - self.assertEqual(preproc_mod._args[0].input_attrs, ["", input_attr_name]) + self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name]) input_attr_names.remove(input_attr_name) - self.assertEqual(preproc_mod._args[0].is_getitems, [False, False]) - # no parent preproc module in FX graph - self.assertEqual(preproc_mod._args[0].preproc_modules, [None, None]) + self.assertEqual(postproc_mod._args[0].is_getitems, [False, False]) + # no parent postproc module in FX graph + self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None]) # pyre-ignore @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_preproc_recursive(self) -> None: + def test_pipeline_postproc_recursive(self) -> None: """ - Test recursive case where multiple arguments to preproc module is derived - from output of another preproc module. For example, + Test recursive case where multiple arguments to postproc module is derived + from output of another postproc module. For example, - out_a, out_b, out_c = preproc_1(input) - out_d = preproc_2(out_a, out_b) + out_a, out_b, out_c = postproc_1(input) + out_d = postproc_2(out_a, out_b) # do something with out_c out = ebc(out_d) """ @@ -1083,7 +1083,7 @@ def test_pipeline_preproc_recursive(self) -> None: randomize_indices=False, )[0].to(self.device) - preproc_module = TestNegSamplingModule( + postproc_module = TestNegSamplingModule( extra_input=extra_input, ) @@ -1091,7 +1091,7 @@ def test_pipeline_preproc_recursive(self) -> None: tables=self.tables[:-1], weighted_tables=self.weighted_tables[:-1], device=self.device, - preproc_module=preproc_module, + postproc_module=postproc_module, ) pipelined_model, pipeline = self._check_output_equal(model, self.sharding_type) @@ -1107,74 +1107,74 @@ def test_pipeline_preproc_recursive(self) -> None: self.assertEqual(len(ebc.forward._args), 1) self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) - self.assertEqual(len(ebc.forward._args[0].preproc_modules), 2) + self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) self.assertIsInstance( - ebc.forward._args[0].preproc_modules[0], PipelinedPreproc + ebc.forward._args[0].postproc_modules[0], PipelinedPostproc ) - self.assertEqual(ebc.forward._args[0].preproc_modules[1], None) + self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) self.assertEqual( - pipelined_ebc.forward._args[0].preproc_modules[0], + pipelined_ebc.forward._args[0].postproc_modules[0], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_nonweighted`. - pipelined_model.module.preproc_nonweighted, + # `postproc_nonweighted`. + pipelined_model.module.postproc_nonweighted, ) self.assertEqual( - pipelined_weighted_ebc.forward._args[0].preproc_modules[0], + pipelined_weighted_ebc.forward._args[0].postproc_modules[0], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_weighted`. - pipelined_model.module.preproc_weighted, + # `postproc_weighted`. + pipelined_model.module.postproc_weighted, ) - # preproc args - self.assertEqual(len(pipeline._pipelined_preprocs), 3) + # postproc args + self.assertEqual(len(pipeline._pipelined_postprocs), 3) # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `_preproc_module`. - parent_preproc_mod = pipelined_model.module._preproc_module + # `_postproc_module`. + parent_postproc_mod = pipelined_model.module._postproc_module - for preproc_mod in pipeline._pipelined_preprocs: + for postproc_mod in pipeline._pipelined_postprocs: # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_nonweighted`. - if preproc_mod == pipelined_model.module.preproc_nonweighted: - self.assertEqual(len(preproc_mod._args), 1) - args = preproc_mod._args[0] + # `postproc_nonweighted`. + if postproc_mod == pipelined_model.module.postproc_nonweighted: + self.assertEqual(len(postproc_mod._args), 1) + args = postproc_mod._args[0] self.assertEqual(args.input_attrs, ["", "idlist_features"]) self.assertEqual(args.is_getitems, [False, False]) - self.assertEqual(len(args.preproc_modules), 2) + self.assertEqual(len(args.postproc_modules), 2) self.assertEqual( - args.preproc_modules[0], - parent_preproc_mod, + args.postproc_modules[0], + parent_postproc_mod, ) - self.assertEqual(args.preproc_modules[1], None) + self.assertEqual(args.postproc_modules[1], None) # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_weighted`. - elif preproc_mod == pipelined_model.module.preproc_weighted: - self.assertEqual(len(preproc_mod._args), 1) - args = preproc_mod._args[0] + # `postproc_weighted`. + elif postproc_mod == pipelined_model.module.postproc_weighted: + self.assertEqual(len(postproc_mod._args), 1) + args = postproc_mod._args[0] self.assertEqual(args.input_attrs, ["", "idscore_features"]) self.assertEqual(args.is_getitems, [False, False]) - self.assertEqual(len(args.preproc_modules), 2) + self.assertEqual(len(args.postproc_modules), 2) self.assertEqual( - args.preproc_modules[0], - parent_preproc_mod, + args.postproc_modules[0], + parent_postproc_mod, ) - self.assertEqual(args.preproc_modules[1], None) - elif preproc_mod == parent_preproc_mod: - self.assertEqual(len(preproc_mod._args), 1) - args = preproc_mod._args[0] + self.assertEqual(args.postproc_modules[1], None) + elif postproc_mod == parent_postproc_mod: + self.assertEqual(len(postproc_mod._args), 1) + args = postproc_mod._args[0] self.assertEqual(args.input_attrs, [""]) self.assertEqual(args.is_getitems, [False]) - self.assertEqual(args.preproc_modules, [None]) + self.assertEqual(args.postproc_modules, [None]) # pyre-ignore @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None: + def test_pipeline_invalid_postproc_inputs_has_trainable_params(self) -> None: """ - Test case where preproc module sits in front of sharded module but this cannot be + Test case where postproc module sits in front of sharded module but this cannot be safely pipelined as it contains trainable params in its child modules """ max_feature_lengths = { @@ -1184,12 +1184,12 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None: "feature_3": 10, } - preproc_module = TestPositionWeightedPreprocModule( + postproc_module = TestPositionWeightedPreprocModule( max_feature_lengths=max_feature_lengths, device=self.device, ) - model = self._setup_model(preproc_module=preproc_module) + model = self._setup_model(postproc_module=postproc_module) ( sharded_model_pipelined, @@ -1203,7 +1203,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None: optimizer=optim_pipelined, device=self.device, execute_all_batches=True, - pipeline_preproc=True, + pipeline_postproc=True, ) data = self._generate_data( @@ -1217,14 +1217,14 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None: # Check that no modules are pipelined self.assertEqual(len(pipeline._pipelined_modules), 0) - self.assertEqual(len(pipeline._pipelined_preprocs), 0) + self.assertEqual(len(pipeline._pipelined_postprocs), 0) # pyre-ignore @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_invalid_preproc_trainable_params_recursive( + def test_pipeline_invalid_postproc_trainable_params_recursive( self, ) -> None: max_feature_lengths = { @@ -1234,7 +1234,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive( "feature_3": 10, } - preproc_module = TestPositionWeightedPreprocModule( + postproc_module = TestPositionWeightedPreprocModule( max_feature_lengths=max_feature_lengths, device=self.device, ) @@ -1243,7 +1243,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive( tables=self.tables[:-1], weighted_tables=self.weighted_tables[:-1], device=self.device, - preproc_module=preproc_module, + postproc_module=postproc_module, ) ( @@ -1258,7 +1258,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive( optimizer=optim_pipelined, device=self.device, execute_all_batches=True, - pipeline_preproc=True, + pipeline_postproc=True, ) data = self._generate_data( @@ -1271,25 +1271,25 @@ def test_pipeline_invalid_preproc_trainable_params_recursive( # Check that no modules are pipelined self.assertEqual(len(pipeline._pipelined_modules), 0) - self.assertEqual(len(pipeline._pipelined_preprocs), 0) + self.assertEqual(len(pipeline._pipelined_postprocs), 0) # pyre-ignore @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None: + def test_pipeline_invalid_postproc_inputs_modify_kjt_recursive(self) -> None: """ - Test case where preproc module cannot be pipelined because at least one of args - is derived from output of another preproc module whose arg(s) cannot be derived + Test case where postproc module cannot be pipelined because at least one of args + is derived from output of another postproc module whose arg(s) cannot be derived from input batch (i.e. it has modifying transformations) """ model = TestModelWithPreproc( tables=self.tables[:-1], weighted_tables=self.weighted_tables[:-1], device=self.device, - preproc_module=None, - run_preproc_inline=True, # run preproc inline, outside a module + postproc_module=None, + run_postproc_inline=True, # run postproc inline, outside a module ) ( @@ -1304,7 +1304,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None: optimizer=optim_pipelined, device=self.device, execute_all_batches=True, - pipeline_preproc=True, + pipeline_postproc=True, ) data = self._generate_data( @@ -1316,13 +1316,13 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None: # Check that only weighted EBC is pipelined self.assertEqual(len(pipeline._pipelined_modules), 1) - self.assertEqual(len(pipeline._pipelined_preprocs), 1) + self.assertEqual(len(pipeline._pipelined_postprocs), 1) self.assertEqual(pipeline._pipelined_modules[0]._is_weighted, True) self.assertEqual( - pipeline._pipelined_preprocs[0], + pipeline._pipelined_postprocs[0], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_weighted`. - sharded_model_pipelined.module.preproc_weighted, + # `postproc_weighted`. + sharded_model_pipelined.module.postproc_weighted, ) # pyre-ignore @@ -1330,11 +1330,11 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None: not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_pipeline_preproc_fwd_values_cached(self) -> None: + def test_pipeline_postproc_fwd_values_cached(self) -> None: """ - Test to check that during model forward, the preproc module pipelined uses the + Test to check that during model forward, the postproc module pipelined uses the saved result from previous iteration(s) and doesn't perform duplicate work - check that fqns for ALL preproc modules are populated in the right train pipeline + check that fqns for ALL postproc modules are populated in the right train pipeline context. """ extra_input = ModelInput.generate( @@ -1346,7 +1346,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: randomize_indices=False, )[0].to(self.device) - preproc_module = TestNegSamplingModule( + postproc_module = TestNegSamplingModule( extra_input=extra_input, ) @@ -1354,7 +1354,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: tables=self.tables[:-1], weighted_tables=self.weighted_tables[:-1], device=self.device, - preproc_module=preproc_module, + postproc_module=postproc_module, ) ( @@ -1369,7 +1369,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: optimizer=optim_pipelined, device=self.device, execute_all_batches=True, - pipeline_preproc=True, + pipeline_postproc=True, ) data = self._generate_data( @@ -1382,22 +1382,22 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: # This was second context that was appended current_context = pipeline.contexts[0] - cached_results = current_context.preproc_fwd_results + cached_results = current_context.postproc_fwd_results self.assertEqual( list(cached_results.keys()), - ["_preproc_module", "preproc_nonweighted", "preproc_weighted"], + ["_postproc_module", "postproc_nonweighted", "postproc_weighted"], ) # next context cached results should be empty next_context = pipeline.contexts[1] - next_cached_results = next_context.preproc_fwd_results + next_cached_results = next_context.postproc_fwd_results self.assertEqual(len(next_cached_results), 0) # After progress, next_context should be populated pipeline.progress(dataloader) self.assertEqual( list(next_cached_results.keys()), - ["_preproc_module", "preproc_nonweighted", "preproc_weighted"], + ["_postproc_module", "postproc_nonweighted", "postproc_weighted"], ) # pyre-ignore @@ -1405,9 +1405,9 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", ) - def test_nested_preproc(self) -> None: + def test_nested_postproc(self) -> None: """ - If preproc module is nested, we should still be able to pipeline it + If postproc module is nested, we should still be able to pipeline it """ extra_input = ModelInput.generate( tables=self.tables, @@ -1418,10 +1418,10 @@ def test_nested_preproc(self) -> None: randomize_indices=False, )[0].to(self.device) - preproc_module = TestNegSamplingModule( + postproc_module = TestNegSamplingModule( extra_input=extra_input, ) - model = self._setup_model(preproc_module=preproc_module) + model = self._setup_model(postproc_module=postproc_module) class ParentModule(nn.Module): def __init__( @@ -1446,7 +1446,7 @@ def forward( # Check that both EC and EBC pipelined self.assertEqual(len(pipeline._pipelined_modules), 2) - self.assertEqual(len(pipeline._pipelined_preprocs), 1) + self.assertEqual(len(pipeline._pipelined_postprocs), 1) class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase): @@ -1808,7 +1808,7 @@ def test_pipelining(self) -> None: optim.step() non_pipelined_outputs.append(pred) - def gpu_preproc(x: StageOut) -> StageOut: + def gpu_postproc(x: StageOut) -> StageOut: return x sdd = SparseDataDistUtil[ModelInput]( @@ -1824,18 +1824,18 @@ def gpu_preproc(x: StageOut) -> StageOut: stream=torch.cuda.Stream(), ), PipelineStage( - name="gpu_preproc", - runnable=gpu_preproc, + name="gpu_postproc", + runnable=gpu_postproc, stream=torch.cuda.Stream(), ), PipelineStage( - name="gpu_preproc_1", - runnable=gpu_preproc, + name="gpu_postproc_1", + runnable=gpu_postproc, stream=torch.cuda.Stream(), ), PipelineStage( - name="gpu_preproc_2", - runnable=gpu_preproc, + name="gpu_postproc_2", + runnable=gpu_postproc, stream=torch.cuda.Stream(), ), PipelineStage( @@ -1885,7 +1885,7 @@ def test_pipeline_flush(self) -> None: model, sharding_type, kernel_type ) - def gpu_preproc(x: StageOut) -> StageOut: + def gpu_postproc(x: StageOut) -> StageOut: return x sdd = SparseDataDistUtil[ModelInput]( @@ -1901,8 +1901,8 @@ def gpu_preproc(x: StageOut) -> StageOut: stream=torch.cuda.Stream(), ), PipelineStage( - name="gpu_preproc", - runnable=gpu_preproc, + name="gpu_postproc", + runnable=gpu_postproc, stream=torch.cuda.Stream(), ), PipelineStage( @@ -2112,8 +2112,8 @@ def test_model_detach(self) -> None: self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1) # Check pipeline exhausted - preproc_input = pipeline.progress(dataloader) - self.assertIsNone(preproc_input) + postproc_input = pipeline.progress(dataloader) + self.assertIsNone(postproc_input) @unittest.skipIf( not torch.cuda.is_available(), @@ -2193,7 +2193,7 @@ def test_pipelining_prefetch( optim.step() non_pipelined_outputs.append(pred) - def gpu_preproc(x: StageOut) -> StageOut: + def gpu_postproc(x: StageOut) -> StageOut: return x sdd = SparseDataDistUtil[ModelInput]( diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py index d728a6d58..a5ed6e7b5 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py @@ -95,14 +95,14 @@ def _setup_model( self, model_type: Type[nn.Module] = TestSparseNN, enable_fsdp: bool = False, - preproc_module: Optional[nn.Module] = None, + postproc_module: Optional[nn.Module] = None, ) -> nn.Module: unsharded_model = model_type( tables=self.tables, weighted_tables=self.weighted_tables, dense_device=self.device, sparse_device=torch.device("meta"), - preproc_module=preproc_module, + postproc_module=postproc_module, ) if enable_fsdp: unsharded_model.over.dhn_arch.linear0 = FSDP( diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 9c4e30326..f23dc0fe0 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -24,7 +24,7 @@ _get_node_args, _rewrite_model, PipelinedForward, - PipelinedPreproc, + PipelinedPostproc, TrainPipelineContext, ) from torchrec.distributed.types import ShardingType @@ -57,16 +57,16 @@ def test_rewrite_model(self) -> None: randomize_indices=False, )[0].to(self.device) - preproc_module = TestNegSamplingModule( + postproc_module = TestNegSamplingModule( extra_input=extra_input, ) - model = self._setup_model(preproc_module=preproc_module) + model = self._setup_model(postproc_module=postproc_module) sharded_model, optim = self._generate_sharded_model_and_optimizer( model, sharding_type, kernel_type, fused_params ) - # Try to rewrite model without ignored_preproc_modules defined, EBC forwards not overwritten to PipelinedForward due to KJT modification + # Try to rewrite model without ignored_postproc_modules defined, EBC forwards not overwritten to PipelinedForward due to KJT modification _rewrite_model( model=sharded_model, batch=None, @@ -86,13 +86,13 @@ def test_rewrite_model(self) -> None: PipelinedForward, ) - # Now provide preproc module explicitly + # Now provide postproc module explicitly _rewrite_model( model=sharded_model, batch=None, context=TrainPipelineContext(), dist_stream=None, - pipeline_preproc=True, + pipeline_postproc=True, ) # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`. @@ -106,27 +106,27 @@ def test_rewrite_model(self) -> None: self.assertEqual( # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `sparse`. - sharded_model.module.sparse.ebc.forward._args[0].preproc_modules[0], + sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_module`. - sharded_model.module.preproc_module, + # `postproc_module`. + sharded_model.module.postproc_module, ) self.assertEqual( # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `sparse`. - sharded_model.module.sparse.weighted_ebc.forward._args[0].preproc_modules[ + sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[ 0 ], # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `preproc_module`. - sharded_model.module.preproc_module, + # `postproc_module`. + sharded_model.module.postproc_module, ) state_dict = sharded_model.state_dict() missing_keys, unexpected_keys = sharded_model.load_state_dict(state_dict) self.assertEqual(missing_keys, []) self.assertEqual(unexpected_keys, []) - def test_pipelined_preproc_state_dict(self) -> None: + def test_pipelined_postproc_state_dict(self) -> None: class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -147,8 +147,8 @@ def forward(self, x): rewritten_model = copy.deepcopy(model) # pyre-ignore[8] - rewritten_model.test_module = PipelinedPreproc( - preproc_module=rewritten_model.test_module, + rewritten_model.test_module = PipelinedPostproc( + postproc_module=rewritten_model.test_module, fqn="test_module", args=[], context=TrainPipelineContext(), @@ -173,10 +173,10 @@ def _create_model_for_snapshot_test( randomize_indices=False, )[0].to(self.device) - preproc_module = TestNegSamplingModule( + postproc_module = TestNegSamplingModule( extra_input=extra_input, ) - model = self._setup_model(preproc_module=preproc_module) + model = self._setup_model(postproc_module=postproc_module) model.to_empty(device=self.device) return model elif source_model_type == ModelType.SHARDED: @@ -195,7 +195,7 @@ def _create_model_for_snapshot_test( batch=None, context=TrainPipelineContext(), dist_stream=None, - pipeline_preproc=True, + pipeline_postproc=True, ) return model else: @@ -217,7 +217,7 @@ def _test_restore_from_snapshot( state_dict = source_model.state_dict() self.assertTrue( - f"preproc_module.{TestNegSamplingModule.TEST_BUFFER_NAME}" + f"postproc_module.{TestNegSamplingModule.TEST_BUFFER_NAME}" in state_dict.keys() ) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index e747a6283..69edabf91 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -49,7 +49,7 @@ In, Out, PipelinedForward, - PipelinedPreproc, + PipelinedPostproc, PipelineStage, PrefetchPipelinedForward, PrefetchTrainPipelineContext, @@ -319,7 +319,8 @@ def __init__( execute_all_batches: bool = True, apply_jit: bool = False, context_type: Type[TrainPipelineContext] = TrainPipelineContext, - pipeline_preproc: bool = False, + # keep for backward compatibility + pipeline_postproc: bool = False, custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, @@ -364,12 +365,12 @@ def __init__( ] = [] self._model_attached = True - self._pipeline_preproc = pipeline_preproc + self._pipeline_postproc = pipeline_postproc self._next_index: int = 0 self.contexts: Deque[TrainPipelineContext] = deque() self._pipelined_modules: List[ShardedModule] = [] - self._pipelined_preprocs: List[PipelinedPreproc] = [] + self._pipelined_postprocs: List[PipelinedPostproc] = [] self.batches: Deque[Optional[In]] = deque() self._dataloader_iter: Optional[Iterator[In]] = None self._dataloader_exhausted: bool = False @@ -425,9 +426,9 @@ def _set_module_context(self, context: TrainPipelineContext) -> None: for module in self._pipelined_modules: module.forward.set_context(context) - for preproc_module in self._pipelined_preprocs: + for postproc_module in self._pipelined_postprocs: # This ensures that next iter model fwd uses cached results - preproc_module.set_context(context) + postproc_module.set_context(context) def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool: batch, context = self.copy_batch_to_gpu(dataloader_iter) @@ -528,7 +529,7 @@ def _pipeline_model( self._pipelined_modules, self._model, self._original_forwards, - self._pipelined_preprocs, + self._pipelined_postprocs, _, ) = _rewrite_model( model=self._model, @@ -538,7 +539,7 @@ def _pipeline_model( batch=batch, apply_jit=self._apply_jit, pipelined_forward=pipelined_forward, - pipeline_preproc=self._pipeline_preproc, + pipeline_postproc=self._pipeline_postproc, ) # initializes input dist, so we can override input dist forwards self.start_sparse_data_dist(batch, context) @@ -615,16 +616,18 @@ def start_sparse_data_dist( with self._stream_context(self._data_dist_stream): _wait_for_batch(batch, self._memcpy_stream) - original_contexts = [p.get_context() for p in self._pipelined_preprocs] + original_contexts = [p.get_context() for p in self._pipelined_postprocs] # Temporarily set context for next iter to populate cache - for preproc_mod in self._pipelined_preprocs: - preproc_mod.set_context(context) + for postproc_mod in self._pipelined_postprocs: + postproc_mod.set_context(context) _start_data_dist(self._pipelined_modules, batch, context) # Restore context for model fwd - for module, context in zip(self._pipelined_preprocs, original_contexts): + for module, context in zip( + self._pipelined_postprocs, original_contexts + ): module.set_context(context) def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None: @@ -728,7 +731,7 @@ def __init__( apply_jit: bool = False, start_batch: int = 900, stash_gradients: bool = False, - pipeline_preproc: bool = False, + pipeline_postproc: bool = True, custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, @@ -740,7 +743,7 @@ def __init__( execute_all_batches=execute_all_batches, apply_jit=apply_jit, context_type=EmbeddingTrainPipelineContext, - pipeline_preproc=pipeline_preproc, + pipeline_postproc=pipeline_postproc, custom_model_fwd=custom_model_fwd, ) self._start_batch = start_batch @@ -837,7 +840,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: iteration: int = context.index or 0 losses, output = self._mlp_forward(cast(In, batch), context) - # After this point, pipelined preproc/module forward won't be called + # After this point, pipelined postproc/module forward won't be called # so we can advance their contexts to the context of the next batch already # and also pop batch and context from self.batches and self.contexts self.dequeue_batch() @@ -845,8 +848,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # batch no longer needed - delete to free up memory del batch - # cached preproc fwd results no longer needed - delete to free up memory - del context.preproc_fwd_results + # cached postproc fwd results no longer needed - delete to free up memory + del context.postproc_fwd_results # batch i+3 self.enqueue_batch(dataloader_iter) @@ -963,9 +966,9 @@ def start_sparse_data_dist( return # Temporarily set context for next iter to populate cache - original_contexts = [p.get_context() for p in self._pipelined_preprocs] - for preproc_mod in self._pipelined_preprocs: - preproc_mod.set_context(context) + original_contexts = [p.get_context() for p in self._pipelined_postprocs] + for postproc_mod in self._pipelined_postprocs: + postproc_mod.set_context(context) with record_function(f"## start_sparse_data_dist {context.index} ##"): with self._stream_context(self._data_dist_stream): @@ -977,7 +980,7 @@ def start_sparse_data_dist( context.events.append(event) # Restore context for model forward - for module, context in zip(self._pipelined_preprocs, original_contexts): + for module, context in zip(self._pipelined_postprocs, original_contexts): module.set_context(context) def start_embedding_lookup( @@ -1044,7 +1047,7 @@ def __init__( device: torch.device, execute_all_batches: bool = True, apply_jit: bool = False, - pipeline_preproc: bool = False, + pipeline_postproc: bool = True, custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, @@ -1056,7 +1059,7 @@ def __init__( execute_all_batches=execute_all_batches, apply_jit=apply_jit, context_type=PrefetchTrainPipelineContext, - pipeline_preproc=pipeline_preproc, + pipeline_postproc=pipeline_postproc, custom_model_fwd=custom_model_fwd, ) self._context = PrefetchTrainPipelineContext(version=0) @@ -1273,8 +1276,8 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]): calling each of the pipeline stages in order. In the example below a fully synchronous will expose the `data_copy` and - `gpu_preproc` calls. After pipelining, the `data_copy` of batch i+2 can be - overlapped with the `gpu_preproc` of batch i+1 and the main model processing of + `gpu_postproc` calls. After pipelining, the `data_copy` of batch i+2 can be + overlapped with the `gpu_postproc` of batch i+1 and the main model processing of batch i. Args: @@ -1295,8 +1298,8 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]): stream=torch.cuda.Stream(), ), PipelineStage( - name="gpu_preproc", - runnable=gpu_preproc, + name="gpu_postproc", + runnable=gpu_postproc, stream=torch.cuda.Stream(), ), ] @@ -1556,7 +1559,7 @@ def __init__( execute_all_batches: bool = True, apply_jit: bool = False, context_type: Type[TrainPipelineContext] = TrainPipelineContext, - pipeline_preproc: bool = False, + pipeline_postproc: bool = False, custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, @@ -1568,7 +1571,7 @@ def __init__( execute_all_batches, apply_jit, context_type, - pipeline_preproc, + pipeline_postproc, custom_model_fwd, ) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 25aa9fe96..70d448acd 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -115,7 +115,7 @@ class TrainPipelineContext: field(default_factory=list) ) events: List[torch.Event] = field(default_factory=list) - preproc_fwd_results: Dict[str, Any] = field(default_factory=dict) + postproc_fwd_results: Dict[str, Any] = field(default_factory=dict) index: Optional[int] = None version: int = ( 0 # 1 is current version, 0 is deprecated but supported for backward compatibility @@ -150,7 +150,7 @@ class EmbeddingTrainPipelineContext(TrainPipelineContext): class PipelineStage: """ A pipeline stage represents a transform to an input that is independent of the - backwards() of the model. Examples include batch H2D transfer, GPU preproc, or + backwards() of the model. Examples include batch H2D transfer, GPU postproc, or gradient-less model processing. Args: @@ -177,17 +177,17 @@ class ArgInfo: input_attrs (List[str]): attributes of input batch, e.g. `batch.attr1.attr2` will produce ["attr1", "attr2"]. is_getitems (List[bool]): `batch[attr1].attr2` will produce [True, False]. - preproc_modules (List[Optional[PipelinedPreproc]]): list of torch.nn.Modules that + postproc_modules (List[Optional[PipelinedPostproc]]): list of torch.nn.Modules that transform the input batch. - constants: constant arguments that are passed to preproc modules. + constants: constant arguments that are passed to postproc modules. name (Optional[str]): name for kwarg of pipelined forward() call or None for a positional arg. """ input_attrs: List[str] is_getitems: List[bool] - # recursive dataclass as preproc_modules.args -> arginfo.preproc_modules -> so on - preproc_modules: List[Optional["PipelinedPreproc"]] + # recursive dataclass as postproc_modules.args -> arginfo.postproc_modules -> so on + postproc_modules: List[Optional["PipelinedPostproc"]] constants: List[Optional[object]] name: Optional[str] @@ -203,20 +203,20 @@ def _build_args_kwargs( for arg_info in fwd_args: if arg_info.input_attrs: arg = initial_input - for attr, is_getitem, preproc_mod, obj in zip( + for attr, is_getitem, postproc_mod, obj in zip( arg_info.input_attrs, arg_info.is_getitems, - arg_info.preproc_modules, + arg_info.postproc_modules, arg_info.constants, ): if obj is not None: arg = obj break - elif preproc_mod is not None: - # preproc will internally run the same logic recursively - # if its args are derived from other preproc modules - # we can get all inputs to preproc mod based on its recorded args_info + arg passed to it - arg = preproc_mod(arg) + elif postproc_mod is not None: + # postproc will internally run the same logic recursively + # if its args are derived from other postproc modules + # we can get all inputs to postproc mod based on its recorded args_info + arg passed to it + arg = postproc_mod(arg) else: if is_getitem: arg = arg[attr] @@ -264,32 +264,32 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: return None -class PipelinedPreproc(torch.nn.Module): +class PipelinedPostproc(torch.nn.Module): """ - Wrapper around preproc module found during model graph traversal for sparse data dist + Wrapper around postproc module found during model graph traversal for sparse data dist pipelining. In addition to the original module, it encapsulates information needed for execution such as list of ArgInfo and the current training pipeline context. Args: - preproc_module (torch.nn.Module): preproc module to run - fqn (str): fqn of the preproc module in the model being pipelined - args (List[ArgInfo]): list of ArgInfo for the preproc module + postproc_module (torch.nn.Module): postproc module to run + fqn (str): fqn of the postproc module in the model being pipelined + args (List[ArgInfo]): list of ArgInfo for the postproc module context (TrainPipelineContext): Training context for the next iteration / batch Returns: Any Example: - preproc = PipelinedPreproc(preproc_module, fqn, args, context) - # module-swap with pipeliend preproc - setattr(model, fqn, preproc) + postproc = PipelinedPostproc(postproc_module, fqn, args, context) + # module-swap with pipeliend postproc + setattr(model, fqn, postproc) """ _FORCE_STATE_DICT_LOAD = True def __init__( self, - preproc_module: torch.nn.Module, + postproc_module: torch.nn.Module, fqn: str, args: List[ArgInfo], context: TrainPipelineContext, @@ -298,7 +298,7 @@ def __init__( dist_stream: Optional[torch.Stream], ) -> None: super().__init__() - self._preproc_module = preproc_module + self._postproc_module = postproc_module self._fqn = fqn self._args = args self._context = context @@ -306,11 +306,11 @@ def __init__( self._dist_stream = dist_stream if not default_stream: logger.warning( - f"Preproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!" + f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!" ) if not dist_stream: logger.warning( - f"Preproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!" + f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!" ) if self._dist_stream: @@ -325,8 +325,8 @@ def __init__( self._stream_context = NoOpStream @property - def preproc_module(self) -> torch.nn.Module: - return self._preproc_module + def postproc_module(self) -> torch.nn.Module: + return self._postproc_module @property def fqn(self) -> str: @@ -341,37 +341,37 @@ def forward(self, *input, **kwargs) -> Any: Returns: Any """ - if self._fqn in self._context.preproc_fwd_results: + if self._fqn in self._context.postproc_fwd_results: # This should only be hit in two cases: # 1) During model forward # During model forward, avoid duplicate work # by returning the cached result from previous # iteration's _start_data_dist - # 2) During _start_data_dist when preproc module is + # 2) During _start_data_dist when postproc module is # shared by more than one args. e.g. if we have - # preproc_out_a = preproc_a(input) - # preproc_out_b = preproc_b(preproc_out_a) <- preproc_a shared - # preproc_out_c = preproc_c(preproc_out_a) <-^ - # When processing preproc_b, we cache value of preproc_a(input) - # so when processing preproc_c, we can reuse preproc_a(input) - res = self._context.preproc_fwd_results[self._fqn] + # postproc_out_a = postproc_a(input) + # postproc_out_b = postproc_b(postproc_out_a) <- postproc_a shared + # postproc_out_c = postproc_c(postproc_out_a) <-^ + # When processing postproc_b, we cache value of postproc_a(input) + # so when processing postproc_c, we can reuse postproc_a(input) + res = self._context.postproc_fwd_results[self._fqn] return res # Everything below should only be called during _start_data_dist stage - # Build up arg and kwargs from recursive call to pass to preproc module - # Arguments to preproc module can be also be a derived product - # of another preproc module call, as long as module is pipelineable + # Build up arg and kwargs from recursive call to pass to postproc module + # Arguments to postproc module can be also be a derived product + # of another postproc module call, as long as module is pipelineable # Use input[0] as _start_data_dist only passes 1 arg args, kwargs = _build_args_kwargs(input[0], self._args) - with record_function(f"## sdd_input_preproc {self._context.index} ##"): + with record_function(f"## sdd_input_postproc {self._context.index} ##"): # should be no-op as we call this in dist stream with self._stream_context(self._dist_stream): - res = self._preproc_module(*args, **kwargs) + res = self._postproc_module(*args, **kwargs) - # Ensure preproc modules output is safe to use from default stream later + # Ensure postproc modules output is safe to use from default stream later if self._default_stream and self._dist_stream: self._default_stream.wait_stream(self._dist_stream) @@ -387,12 +387,12 @@ def forward(self, *input, **kwargs) -> Any: recursive_record_stream(res, self._default_stream) elif self._context.index == 0: logger.warning( - f"Result of preproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!" + f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!" ) with self._stream_context(self._default_stream): # Cache results, only during _start_data_dist - self._context.preproc_fwd_results[self._fqn] = res + self._context.postproc_fwd_results[self._fqn] = res return res @@ -417,18 +417,18 @@ def named_modules( if self not in memo: if remove_duplicate: memo.add(self) - # This is needed because otherwise the rewrite won't find the existing preproc, and will create a new one + # This is needed because otherwise the rewrite won't find the existing postproc, and will create a new one # Also, `named_modules` need to include self - see base implementation in the nn.modules.Module yield prefix, self - # Difference from base implementation is here - the child name (_preproc_module) is not added to the prefix - yield from self._preproc_module.named_modules( + # Difference from base implementation is here - the child name (_postproc_module) is not added to the prefix + yield from self._postproc_module.named_modules( memo, prefix, remove_duplicate ) def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, torch.nn.Parameter]]: - yield from self._preproc_module.named_parameters( + yield from self._postproc_module.named_parameters( prefix, recurse, remove_duplicate, @@ -437,7 +437,9 @@ def named_parameters( def named_buffers( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, torch.Tensor]]: - yield from self._preproc_module.named_buffers(prefix, recurse, remove_duplicate) + yield from self._postproc_module.named_buffers( + prefix, recurse, remove_duplicate + ) # pyre-ignore [14] def state_dict( @@ -451,7 +453,7 @@ def state_dict( destination = OrderedDict() # pyre-ignore [16] destination._metadata = OrderedDict() - self._preproc_module.state_dict( + self._postproc_module.state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars ) return destination @@ -462,7 +464,7 @@ def load_state_dict( state_dict: OrderedDict[str, torch.Tensor], strict: bool = True, ) -> _IncompatibleKeys: - return self._preproc_module.load_state_dict(state_dict, strict=strict) + return self._postproc_module.load_state_dict(state_dict, strict=strict) TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext) @@ -849,7 +851,7 @@ def _check_args_for_call_module( return False -def _check_preproc_pipelineable( +def _check_postproc_pipelineable( module: torch.nn.Module, ) -> bool: for _, _ in module.named_parameters(recurse=True): @@ -861,39 +863,39 @@ def _check_preproc_pipelineable( return True -def _find_preproc_module_recursive( +def _find_postproc_module_recursive( module: torch.nn.Module, - preproc_module_fqn: str, + postproc_module_fqn: str, ) -> Optional[torch.nn.Module]: """ - Finds the preproc module in the model. + Finds the postproc module in the model. """ for name, child in module.named_modules(): - if name == preproc_module_fqn: + if name == postproc_module_fqn: return child return None -def _swap_preproc_module_recursive( +def _swap_postproc_module_recursive( module: torch.nn.Module, to_swap_module: torch.nn.Module, - preproc_module_fqn: str, + postproc_module_fqn: str, path: str = "", ) -> torch.nn.Module: """ - Swaps the preproc module in the model. + Swaps the postproc module in the model. """ - if isinstance(module, PipelinedPreproc): + if isinstance(module, PipelinedPostproc): return module - if path == preproc_module_fqn: + if path == postproc_module_fqn: return to_swap_module for name, child in module.named_children(): - child = _swap_preproc_module_recursive( + child = _swap_postproc_module_recursive( child, to_swap_module, - preproc_module_fqn, + postproc_module_fqn, path + "." + name if path else name, ) setattr(module, name, child) @@ -906,12 +908,12 @@ def _get_node_args_helper( # pyre-ignore arguments, num_found: int, - pipelined_preprocs: Set[PipelinedPreproc], + pipelined_postprocs: Set[PipelinedPostproc], context: TrainPipelineContext, - pipeline_preproc: bool, - # Add `None` constants to arg info only for preproc modules + pipeline_postproc: bool, + # Add `None` constants to arg info only for postproc modules # Defaults to False for backward compatibility - for_preproc_module: bool = False, + for_postproc_module: bool = False, default_stream: Optional[torch.Stream] = None, dist_stream: Optional[torch.Stream] = None, ) -> Tuple[List[ArgInfo], int]: @@ -921,15 +923,15 @@ def _get_node_args_helper( """ arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))] for arg, arg_info in zip(arguments, arg_info_list): - if not for_preproc_module and arg is None: + if not for_postproc_module and arg is None: num_found += 1 continue while True: if not isinstance(arg, torch.fx.Node): - if pipeline_preproc: + if pipeline_postproc: arg_info.input_attrs.insert(0, "") arg_info.is_getitems.insert(0, False) - arg_info.preproc_modules.insert(0, None) + arg_info.postproc_modules.insert(0, None) if isinstance(arg, (fx_immutable_dict, fx_immutable_list)): # Make them mutable again, in case in-place updates are made arg_info.constants.insert(0, arg.copy()) @@ -953,13 +955,13 @@ def _get_node_args_helper( else: arg_info.input_attrs.append(key) arg_info.is_getitems.append(False) - arg_info.preproc_modules.append(None) + arg_info.postproc_modules.append(None) arg_info.constants.append(None) else: # no-op arg_info.input_attrs.insert(0, "") arg_info.is_getitems.insert(0, False) - arg_info.preproc_modules.insert(0, None) + arg_info.postproc_modules.insert(0, None) arg_info.constants.insert(0, None) num_found += 1 @@ -976,7 +978,7 @@ def _get_node_args_helper( # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, False) - arg_info.preproc_modules.insert(0, None) + arg_info.postproc_modules.insert(0, None) arg_info.constants.insert(0, None) arg = child_node.args[0] elif ( @@ -991,7 +993,7 @@ def _get_node_args_helper( # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, True) - arg_info.preproc_modules.insert(0, None) + arg_info.postproc_modules.insert(0, None) arg_info.constants.insert(0, None) arg = child_node.args[0] elif ( @@ -1033,43 +1035,43 @@ def _get_node_args_helper( # pyre-ignore[6] arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, True) - arg_info.preproc_modules.insert(0, None) + arg_info.postproc_modules.insert(0, None) arg_info.constants.insert(0, None) arg = child_node.args[0] elif child_node.op == "call_module": - preproc_module_fqn = str(child_node.target) - preproc_module = _find_preproc_module_recursive( - model, preproc_module_fqn + postproc_module_fqn = str(child_node.target) + postproc_module = _find_postproc_module_recursive( + model, postproc_module_fqn ) - if not pipeline_preproc: + if not pipeline_postproc: logger.warning( - f"Found module {preproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_preproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_preproc=True`" + f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" ) break - if not preproc_module: + if not postproc_module: # Could not find such module, should not happen break - if isinstance(preproc_module, PipelinedPreproc): + if isinstance(postproc_module, PipelinedPostproc): # Already did module swap and registered args, early exit arg_info.input_attrs.insert(0, "") # dummy value arg_info.is_getitems.insert(0, False) - pipelined_preprocs.add(preproc_module) - arg_info.preproc_modules.insert(0, preproc_module) + pipelined_postprocs.add(postproc_module) + arg_info.postproc_modules.insert(0, postproc_module) arg_info.constants.insert(0, None) num_found += 1 break - if not isinstance(preproc_module, torch.nn.Module): + if not isinstance(postproc_module, torch.nn.Module): logger.warning( - f"Expected preproc_module to be nn.Module but was {type(preproc_module)}" + f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" ) break # check if module is safe to pipeline i.e.no trainable param - if not _check_preproc_pipelineable(preproc_module): + if not _check_postproc_pipelineable(postproc_module): break # For module calls, `self` isn't counted @@ -1078,46 +1080,46 @@ def _get_node_args_helper( # module call without any args, assume KJT modified break - # recursive call to check that all inputs to this preproc module - # is either made of preproc module or non-modifying train batch input + # recursive call to check that all inputs to this postproc module + # is either made of postproc module or non-modifying train batch input # transformations - preproc_args, num_found_safe_preproc_args = _get_node_args( + postproc_args, num_found_safe_postproc_args = _get_node_args( model, child_node, - pipelined_preprocs, + pipelined_postprocs, context, - pipeline_preproc, + pipeline_postproc, True, default_stream=default_stream, dist_stream=dist_stream, ) - if num_found_safe_preproc_args == total_num_args: + if num_found_safe_postproc_args == total_num_args: logger.info( - f"""Module {preproc_module} is a valid preproc module (no + f"""Module {postproc_module} is a valid postproc module (no trainable params and inputs can be derived from train batch input - via a series of either valid preproc modules or non-modifying + via a series of either valid postproc modules or non-modifying transformations) and will be applied during sparse data dist stage""" ) - pipelined_preproc_module = PipelinedPreproc( - preproc_module, - preproc_module_fqn, - preproc_args, + pipelined_postproc_module = PipelinedPostproc( + postproc_module, + postproc_module_fqn, + postproc_args, context, default_stream=default_stream, dist_stream=dist_stream, ) # module swap - _swap_preproc_module_recursive( - model, pipelined_preproc_module, preproc_module_fqn + _swap_postproc_module_recursive( + model, pipelined_postproc_module, postproc_module_fqn ) arg_info.input_attrs.insert(0, "") # dummy value arg_info.is_getitems.insert(0, False) - pipelined_preprocs.add(pipelined_preproc_module) - arg_info.preproc_modules.insert(0, pipelined_preproc_module) + pipelined_postprocs.add(pipelined_postproc_module) + arg_info.postproc_modules.insert(0, pipelined_postproc_module) arg_info.constants.insert(0, None) num_found += 1 @@ -1133,10 +1135,10 @@ def _get_node_args_helper( def _get_node_args( model: torch.nn.Module, node: Node, - pipelined_preprocs: Set[PipelinedPreproc], + pipelined_postprocs: Set[PipelinedPostproc], context: TrainPipelineContext, - pipeline_preproc: bool, - for_preproc_module: bool = False, + pipeline_postproc: bool, + for_postproc_module: bool = False, default_stream: Optional[torch.Stream] = None, dist_stream: Optional[torch.Stream] = None, ) -> Tuple[List[ArgInfo], int]: @@ -1146,10 +1148,10 @@ def _get_node_args( model, node.args, num_found, - pipelined_preprocs, + pipelined_postprocs, context, - pipeline_preproc, - for_preproc_module, + pipeline_postproc, + for_postproc_module, default_stream=default_stream, dist_stream=dist_stream, ) @@ -1157,10 +1159,10 @@ def _get_node_args( model, node.kwargs.values(), num_found, - pipelined_preprocs, + pipelined_postprocs, context, - pipeline_preproc, - for_preproc_module, + pipeline_postproc, + for_postproc_module, default_stream=default_stream, dist_stream=dist_stream, ) @@ -1286,13 +1288,13 @@ def _rewrite_model( # noqa C901 batch: Optional[In] = None, apply_jit: bool = False, pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward, - pipeline_preproc: bool = False, + pipeline_postproc: bool = False, default_stream: Optional[torch.Stream] = None, ) -> Tuple[ List[ShardedModule], torch.nn.Module, List[Callable[..., Any]], - List[PipelinedPreproc], + List[PipelinedPostproc], List[str], ]: input_model = model @@ -1332,7 +1334,7 @@ def _rewrite_model( # noqa C901 pipelined_forwards = [] original_forwards = [] - pipelined_preprocs: Set[PipelinedPreproc] = set() + pipelined_postprocs: Set[PipelinedPostproc] = set() non_pipelined_sharded_modules = [] for node in graph.nodes: @@ -1343,9 +1345,9 @@ def _rewrite_model( # noqa C901 arg_info_list, num_found = _get_node_args( model, node, - pipelined_preprocs, + pipelined_postprocs, context, - pipeline_preproc, + pipeline_postproc, default_stream=default_stream, dist_stream=dist_stream, ) @@ -1386,7 +1388,7 @@ def _rewrite_model( # noqa C901 pipelined_forwards, input_model, original_forwards, - list(pipelined_preprocs), + list(pipelined_postprocs), non_pipelined_sharded_modules, ) @@ -1674,7 +1676,7 @@ def detach(self) -> torch.nn.Module: def start_sparse_data_dist(self, batch: In) -> In: if not self.initialized: # Step 1: Pipeline input dist in trec sharded modules - # TODO (yhshin): support preproc modules for `StagedTrainPipeline` + # TODO (yhshin): support postproc modules for `StagedTrainPipeline` ( self._pipelined_modules, self.model, diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 7be8c6d15..8a3db1209 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -15,7 +15,7 @@ from collections import OrderedDict from contextlib import AbstractContextManager, nullcontext from dataclasses import asdict -from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import torch from fbgemm_gpu.split_embedding_configs import EmbOptimType @@ -511,3 +511,44 @@ def interaction(self, *args, **kwargs) -> None: pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin + + +def create_global_tensor_shape_stride_from_metadata( + parameter_sharding: ParameterSharding, devices_per_node: Optional[int] = None +) -> Tuple[torch.Size, Tuple[int, int]]: + """ + Create a global tensor shape and stride from shard metadata. + + Returns: + torch.Size: global tensor shape. + tuple: global tensor stride. + """ + size = None + if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value: + row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0] # pyre-ignore[16] + col_dim = 0 + for shard in parameter_sharding.sharding_spec.shards: + col_dim += shard.shard_sizes[1] + size = torch.Size([row_dim, col_dim]) + elif ( + parameter_sharding.sharding_type == ShardingType.ROW_WISE.value + or parameter_sharding.sharding_type == ShardingType.TABLE_ROW_WISE.value + ): + row_dim = 0 + col_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[1] + for shard in parameter_sharding.sharding_spec.shards: + row_dim += shard.shard_sizes[0] + size = torch.Size([row_dim, col_dim]) + elif parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value: + size = torch.Size(parameter_sharding.sharding_spec.shards[0].shard_sizes) + elif parameter_sharding.sharding_type == ShardingType.GRID_SHARD.value: + # we need node group size to appropriately calculate global shape from shard + assert devices_per_node is not None + row_dim, col_dim = 0, 0 + num_cw_shards = len(parameter_sharding.sharding_spec.shards) // devices_per_node + for _ in range(num_cw_shards): + col_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[1] + for _ in range(devices_per_node): + row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0] + size = torch.Size([row_dim, col_dim]) + return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1)) # pyre-ignore[7]