forked from pytorch/torchrec
add DTensor to optimizer state dict (pytorch#2585)
Summary: To support 2D parallelism checkpointing, we introduce DTensor to the optimizer state dict. It is enabled through fused_params["output_dtensor"] = True: when table shards are output as DTensor, the optimizer shards are as well. This diff lets us leverage N-dimensional device meshes with support for arbitrary replication/sharding groups, which makes checkpointing straightforward because DCP/Modelstore support replicated/sharded placements on a device mesh (something ShardedTensor does not support). Differential Revision: D65555455
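To make the mechanism concrete, below is a minimal, hypothetical sketch (not code from this commit): it enables the `output_dtensor` fused param named in the summary, builds a 2D device mesh, and shows the kind of replicated/sharded DTensor placement the optimizer state can then carry. The table config, mesh shape, and the example momentum buffer are illustrative assumptions.

```python
# A minimal sketch, assuming a standard torchrec sharding setup and a
# 2D ("replicate" x "shard") device mesh; details may differ from the diff.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor

from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# Flag referenced in the summary: with output_dtensor=True, table shards are
# output as DTensor and the optimizer shards follow.
sharder = EmbeddingBagCollectionSharder(fused_params={"output_dtensor": True})

# Illustrative table config (not from the commit).
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=64,
            num_embeddings=1_000_000,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

# N-dimensional mesh: replicate across one dim, shard across the other.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# Optimizer state (e.g. a momentum buffer) can then be expressed as a DTensor
# with per-dim placements, which DCP resolves directly at checkpoint time.
momentum = distribute_tensor(
    torch.zeros(1_000_000, 64), mesh, placements=[Replicate(), Shard(0)]
)
assert isinstance(momentum, DTensor)

# After sharding the model (e.g. with DistributedModelParallel) and building
# the fused optimizer, torch.distributed.checkpoint can save the replicated/
# sharded placements on this mesh, e.g.:
#   import torch.distributed.checkpoint as dcp
#   dcp.save({"model": model.state_dict(), "optim": optim.state_dict()},
#            checkpoint_id="ckpt/")
```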
1 parent f450c59 · commit 9aab41c
Showing 3 changed files with 106 additions and 21 deletions.