forked from pytorch/torchrec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2D Parallelism in TorchRec (pytorch#2554)
Summary: In this diff we introduce a new parallelism strategy for scaling recommendation model training called 2D parallel. In this case, we scale model parallel through data parallel, hence, the 2D name. Our new entry point, DMPCollection, subclasses DMP and is meant to be a drop in replacement to integrate 2D parallelism in distributed training. By setting the total number of GPUs to train across and the number of GPUs to locally shard across (aka one replication group), users can train their models in the same training loop but now over a larger number of GPUs. The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. Under this scheme the supported sharding types are RW, CW, and GRID. TWRW is not supported due to no longer being able to take advantage of the intra node bandwidth in the 2D scheme. Example Use Case: Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be: - Group 0, DMP 0: [0, 2, 4, 6] - Group 1, DMP 1: [1, 3, 5, 7] Each group receives an identical sharding plan for their local world size and ranks. If we have one table sharded in each DMP, with one shard on each rank in the group, each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1. The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7]. NOTE: We have to pass global process group to the DDPWrapper otherwise some of the unsharded parameters will not get optimizer applied to them. will result in numerically inaccurate results Differential Revision: D61643328
- Loading branch information
1 parent
9fb6b8e
commit f106537
Showing
12 changed files
with
970 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.