Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add warning log for sharded module that cannot be pipelined (#1873)
Summary: # context 1. In a HSNN training flow, we found the pipeline wasn't working as expected, i.e., there was no overlapping between the "input_dist" and "forward". 2. After deep diving into the pipeline, we found the [`pipelined_modules`](https://www.internalfb.com/code/fbsource/[81d3af607216]/fbcode/torchrec/distributed/train_pipeline/train_pipelines.py?lines=246) is empty. (search for "@@@===@@@===@@@ _pipelined_modules" in P1211507773). 3. Discussed with sarckk, it's likely because the sharded modules should only directly take KJT as input, without additional operations. In this case, a [KJT concatenation](https://www.internalfb.com/code/fbsource/[5e6419a66017]/fbcode/aps_models/ads/gmp/models/hsnn_model/experimental/post_match_pa_ctr_model/slimdsnn_teacher_arch_522511655_human_readable.py?lines=4822-4828) is called after the data_iter. 4. This un-desired outcome could be captured during the [`_rewrite_model`](https://www.internalfb.com/code/fbsource/[9e1cebc47bff]/fbcode/torchrec/distributed/train_pipeline/utils.py?lines=578%2C628) function. (search for "_rewrite_model sharded" in P1211507773). And we think this information is very helpful. # changes 1. When a node.target is one of the "sharded_modules", we do expect the intention for this node/module is pipelineable. 1. So we add a "logger.warning" to warn the user that this module is actually not pipelined, and should be aware of potential un-desired outcome. Reviewed By: joshuadeng Differential Revision: D56051238
- Loading branch information