Skip to content

Commit

Permalink
Add warning log for sharded module that cannot be pipelined (#1873)
Browse files Browse the repository at this point in the history
Summary:

# context
1. In a HSNN training flow, we found the pipeline wasn't working as expected, i.e., there was no overlapping between the "input_dist" and "forward".
2. After deep diving into the pipeline, we found the [`pipelined_modules`](https://www.internalfb.com/code/fbsource/[81d3af607216]/fbcode/torchrec/distributed/train_pipeline/train_pipelines.py?lines=246) is empty. (search for "@@@===@@@===@@@ _pipelined_modules" in P1211507773).
3. Discussed with sarckk, it's likely because the sharded modules should only directly take KJT as input, without additional operations. In this case, a [KJT concatenation](https://www.internalfb.com/code/fbsource/[5e6419a66017]/fbcode/aps_models/ads/gmp/models/hsnn_model/experimental/post_match_pa_ctr_model/slimdsnn_teacher_arch_522511655_human_readable.py?lines=4822-4828) is called after the data_iter.
4. This un-desired outcome could be captured during the [`_rewrite_model`](https://www.internalfb.com/code/fbsource/[9e1cebc47bff]/fbcode/torchrec/distributed/train_pipeline/utils.py?lines=578%2C628) function. (search for "_rewrite_model sharded" in P1211507773). And we think this information is very helpful.

# changes
1. When a node.target is one of the "sharded_modules", we do expect the intention for this node/module is pipelineable.
1. So we add a "logger.warning" to warn the user that this module is actually not pipelined, and should be aware of potential un-desired outcome.

Reviewed By: joshuadeng

Differential Revision: D56051238
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Apr 12, 2024
1 parent 8417057 commit 5a63a2d
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,10 @@ def _rewrite_model( # noqa C901
dist_stream,
)
pipelined_forwards.append(child)
else:
logger.warning(
f"Module '{node.target}'' will not be pipelined, due to input modifications"
)

# JIT script unsharded modules if applicable.
if apply_jit:
Expand Down

0 comments on commit 5a63a2d

Please sign in to comment.