From 5a63a2dc5e8b6e8d1198179ddffffd19bea732dd Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Fri, 12 Apr 2024 09:29:34 -0700
Subject: [PATCH] Add warning log for sharded module that cannot be pipelined (#1873)

Summary:
# context
1. In an HSNN training flow, we found the pipeline wasn't working as expected, i.e., there was no overlap between the "input_dist" and "forward" stages.
2. After digging into the pipeline, we found that [`pipelined_modules`](https://www.internalfb.com/code/fbsource/[81d3af607216]/fbcode/torchrec/distributed/train_pipeline/train_pipelines.py?lines=246) is empty (search for "@@@===@@@===@@@ _pipelined_modules" in P1211507773).
3. After discussing with sarckk, the likely cause is that a sharded module must take the KJT directly as input, without additional operations in between. In this case, a [KJT concatenation](https://www.internalfb.com/code/fbsource/[5e6419a66017]/fbcode/aps_models/ads/gmp/models/hsnn_model/experimental/post_match_pa_ctr_model/slimdsnn_teacher_arch_522511655_human_readable.py?lines=4822-4828) is called after the data_iter.
4. This undesired outcome can be detected in the [`_rewrite_model`](https://www.internalfb.com/code/fbsource/[9e1cebc47bff]/fbcode/torchrec/distributed/train_pipeline/utils.py?lines=578%2C628) function (search for "_rewrite_model sharded" in P1211507773), and surfacing it to the user is helpful.

# changes
1. When a node.target is one of the "sharded_modules", we expect that node/module to be pipelineable.
2. So we add a "logger.warning" to tell the user that this module is actually not pipelined, so they are aware of the potential undesired outcome.

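The decision described above can be sketched in plain Python. This is a simplified illustration, not the actual `_rewrite_model` code: the `Node` class, the `input_is_raw_batch` flag, the `select_pipelined` helper, and the module names are all hypothetical stand-ins, and the log message is modeled on the one added in this patch.

```python
import logging
from dataclasses import dataclass
from typing import List, Set

logger = logging.getLogger(__name__)


@dataclass
class Node:
    # Hypothetical stand-in for a traced graph node; not the real torch.fx class.
    target: str
    # False when the input was modified before reaching the module
    # (for example, a KJT concatenation after the data iterator).
    input_is_raw_batch: bool


def select_pipelined(nodes: List[Node], sharded_modules: Set[str]) -> List[str]:
    """Return the sharded targets that can be pipelined and warn about the rest."""
    pipelined: List[str] = []
    for node in nodes:
        if node.target not in sharded_modules:
            continue  # only sharded modules are candidates for pipelining
        if node.input_is_raw_batch:
            pipelined.append(node.target)
        else:
            logger.warning(
                f"Module '{node.target}' will not be pipelined, due to input modifications"
            )
    return pipelined


# Hypothetical module names, for illustration only.
nodes = [
    Node("sparse_arch.ebc", input_is_raw_batch=True),
    Node("sparse_arch.ebc_concat", input_is_raw_batch=False),
    Node("dense_arch", input_is_raw_batch=True),
]
result = select_pipelined(nodes, {"sparse_arch.ebc", "sparse_arch.ebc_concat"})
```

Under these assumptions, only `sparse_arch.ebc` ends up in `result`; `sparse_arch.ebc_concat` triggers the warning, and `dense_arch` is skipped silently because it is not sharded. An empty result combined with warnings in the log corresponds to the situation described in the context section.
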
Reviewed By: joshuadeng

Differential Revision: D56051238

---
 torchrec/distributed/train_pipeline/utils.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index bb57bbf5e..114f021e2 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -636,6 +636,10 @@ def _rewrite_model(  # noqa C901
                     dist_stream,
                 )
                 pipelined_forwards.append(child)
+            else:
+                logger.warning(
+                    f"Module '{node.target}' will not be pipelined, due to input modifications"
+                )

     # JIT script unsharded modules if applicable.
     if apply_jit: