Skip to content

Commit

Permalink
add DTensor to optimizer state dict (pytorch#2585)
Browse files Browse the repository at this point in the history
Summary:

To support 2D parallelism checkpointing, we introduce DTensor to the optimizer state dict. It is enabled through fused_params["output_dtensor"] = True, meaning when table shards are outputted in DTensor so are optimizer shards.

This diff allows us to leverage N-dimensional device meshes with support for abritrary replication/sharding groups - making checkpointing easy as DCP/Modelstore support replicated/sharded placements on a device mesh (something that is unsupported in ShardedTensor)

Differential Revision: D65555455
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Dec 9, 2024
1 parent f450c59 commit 9aab41c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 21 deletions.
55 changes: 47 additions & 8 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,18 @@
PartiallyMaterializedTensor,
)
from torch import nn
from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
from torchrec.distributed.comm import get_local_rank, get_node_group_size
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
from torchrec.distributed.embedding_types import (
compute_kernel_to_embedding_location,
DTensorMetadata,
GroupedEmbeddingConfig,
)
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
from torchrec.distributed.types import (
Shard,
ShardedTensor,
Expand Down Expand Up @@ -213,6 +216,7 @@ class ShardParams:
optimizer_states: List[Optional[Tuple[torch.Tensor]]]
local_metadata: List[ShardMetadata]
embedding_weights: List[torch.Tensor]
dtensor_metadata: List[DTensorMetadata]

def get_optimizer_single_value_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
Expand Down Expand Up @@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
continue
if table_config.name not in table_to_shard_params:
table_to_shard_params[table_config.name] = ShardParams(
optimizer_states=[], local_metadata=[], embedding_weights=[]
optimizer_states=[],
local_metadata=[],
embedding_weights=[],
dtensor_metadata=[],
)
optimizer_state_values = None
if optimizer_states:
Expand All @@ -410,6 +417,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
table_to_shard_params[table_config.name].local_metadata.append(
local_metadata
)
table_to_shard_params[table_config.name].dtensor_metadata.append(
table_config.dtensor_metadata
)
table_to_shard_params[table_config.name].embedding_weights.append(weight)

seen_tables = set()
Expand Down Expand Up @@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
# pyre-ignore
def get_sharded_optim_state(
momentum_idx: int, state_key: str
) -> ShardedTensor:
) -> Union[ShardedTensor, DTensor]:
assert momentum_idx > 0
momentum_local_shards: List[Shard] = []
optimizer_sharded_tensor_metadata: ShardedTensorMetadata
Expand Down Expand Up @@ -528,12 +538,41 @@ def get_sharded_optim_state(
)
)

# TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=momentum_local_shards,
sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
process_group=self._pg,
)
# Convert optimizer state to DTensor if enabled
if table_config.dtensor_metadata:
# if rowwise state we do Shard(0), regardless of how the table is sharded
if optim_state.dim() == 1:
stride = (1,)
placements = (
(Replicate(), DTensorShard(0))
if table_config.dtensor_metadata.mesh.ndim == 2
else (DTensorShard(0),)
)
else:
stride = table_config.dtensor_metadata.stride
placements = table_config.dtensor_metadata.placements

return DTensor.from_local(
local_tensor=LocalShardsWrapper(
local_shards=[x.tensor for x in momentum_local_shards],
local_offsets=[ # pyre-ignore[6]
x.metadata.shard_offsets
for x in momentum_local_shards
],
),
device_mesh=table_config.dtensor_metadata.mesh,
placements=placements,
shape=optimizer_sharded_tensor_metadata.size,
stride=stride,
run_check=False,
)
else:
# TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=momentum_local_shards,
sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
process_group=self._pg,
)

num_states: int = min(
# pyre-ignore
Expand Down
52 changes: 39 additions & 13 deletions torchrec/distributed/shards_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,15 @@ def __new__(

# we calculate the total tensor size by "concat" on second tensor dimension
cat_tensor_shape = list(local_shards[0].size())
if len(local_shards) > 1: # column-wise sharding
if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[1] += shard.size()[1]

# in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[0] += shard.size()[0]

wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
wrapper_shape = torch.Size(cat_tensor_shape)
chunks_meta = [
Expand Down Expand Up @@ -110,6 +115,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
aten.equal.default: cls.handle_equal,
aten.detach.default: cls.handle_detach,
aten.clone.default: cls.handle_clone,
aten.new_empty.default: cls.handle_new_empty,
}

if func in dispatcher:
Expand Down Expand Up @@ -153,18 +159,28 @@ def handle_to_copy(args, kwargs):
def handle_view(args, kwargs):
view_shape = args[1]
res_shards_list = []
if (
len(args[0].local_shards()) > 1
and args[0].storage_metadata().size[0] == view_shape[0]
and args[0].storage_metadata().size[1] == view_shape[1]
):
# This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
# init calls view_as() on the global tensor shape
# will fail because the view shape is not applicable to individual shards.
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
if len(args[0].local_shards()) > 1:
if args[0].local_shards()[0].ndim == 2:
assert (
args[0].storage_metadata().size[0] == view_shape[0]
and args[0].storage_metadata().size[1] == view_shape[1]
)
# This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
# init calls view_as() on the global tensor shape
# will fail because the view shape is not applicable to individual shards.
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
elif args[0].local_shards()[0].ndim == 1:
assert args[0].storage_metadata().size[0] == view_shape[0]
# This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
else:
raise NotImplementedError("No support for view on tensors ndim > 2")
else:
# view is called per shard
res_shards_list = [
Expand Down Expand Up @@ -220,6 +236,16 @@ def handle_clone(args, kwargs):
]
return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_new_empty(args, kwargs):
self_ls = args[0]
return LocalShardsWrapper(
[torch.empty_like(shard) for shard in self_ls._local_shards],
self_ls.local_offsets(),
)

@property
def device(self) -> torch._C.device: # type: ignore[override]
return (
Expand Down
20 changes: 20 additions & 0 deletions torchrec/optim/keyed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.tensor import DTensor
from torchrec.distributed.shards_wrapper import LocalShardsWrapper


OptimizerFactory = Callable[[List[Union[torch.Tensor, ShardedTensor]]], optim.Optimizer]
Expand Down Expand Up @@ -134,6 +136,24 @@ def _update_param_state_dict_object(
)
for shard, new_shard in zip(v.local_shards(), new_v.local_shards()):
shard.tensor.detach().copy_(new_shard.tensor)
elif isinstance(v, DTensor):
assert isinstance(new_v, DTensor)
# pyre-ignore[16]
if isinstance(v.to_local(), LocalShardsWrapper):
assert isinstance(new_v.to_local(), LocalShardsWrapper)
num_shards = len(v.to_local().local_shards())
num_new_shards = len(new_v.to_local().local_shards())
if num_shards != num_new_shards:
raise ValueError(
f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}"
)
for shard, new_shard in zip(
v.to_local().local_shards(), new_v.to_local().local_shards()
):
shard.detach().copy_(new_shard)
else:
assert isinstance(new_v.to_local(), torch.Tensor)
v.detach().copy_(new_v)
elif isinstance(v, torch.Tensor):
v.detach().copy_(new_v)
else:
Expand Down

0 comments on commit 9aab41c

Please sign in to comment.