make PartiallyMaterializedTensor work with checkpointing (pytorch#2531)
Summary:
Pull Request resolved: pytorch#2531

create ShardedTensor from PartiallyMaterializedTensor, so that KVTensorWrapper can be used for checkpointing.

Reviewed By: xunnanxu

Differential Revision: D65281052

fbshipit-source-id: ae01d762c31f5acef408c18de3f6947ae2e05528
Yulu Jia authored and facebook-github-bot committed Nov 8, 2024
1 parent ec6a5a8 commit 9a4d8a8
Showing 4 changed files with 137 additions and 19 deletions.
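To make the change easier to follow, here is a minimal consumer-side sketch under stated assumptions: the helper name and the `state` argument are hypothetical, and only ShardedTensor, PartiallyMaterializedTensor, and the hook behavior come from the diffs below.

from typing import Dict, List

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor


def collect_kv_checkpoint_tensors(state: Dict[str, torch.Tensor]) -> Dict[str, List[object]]:
    # Sketch only: `state` is assumed to be the state_dict() of a
    # ShardedEmbeddingBagCollection whose tables use the KEY_VALUE
    # (SSD/RocksDB-backed) compute kernel.
    out: Dict[str, List[object]] = {}
    for name, value in state.items():
        if isinstance(value, ShardedTensor):
            # For KEY_VALUE tables, after the post_state_dict_hook added below, each
            # local shard's .tensor is a PartiallyMaterializedTensor holding a valid
            # RocksDB snapshot handle, so a checkpoint writer (e.g. through
            # KVTensorWrapper) can read row windows from it.
            out[name] = [shard.tensor for shard in value.local_shards()]
    return out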
83 changes: 72 additions & 11 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -42,6 +42,9 @@
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
PartiallyMaterializedTensor,
)
from torch import nn
from torchrec.distributed.comm import get_local_rank, get_local_size
from torchrec.distributed.composable.table_batched_embedding_slice import (
@@ -585,6 +588,27 @@ def _gen_named_parameters_by_table_ssd(
yield (table_name, weight)


def _gen_named_parameters_by_table_ssd_pmt(
emb_module: SSDTableBatchedEmbeddingBags,
table_name_to_count: Dict[str, int],
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
) -> Iterator[Tuple[str, nn.Parameter]]:
"""
Return an iterator over module parameters that are embedding tables, yielding both the table
name as well as the parameter itself. The embedding table is in the form of
PartiallyMaterializedTensor to support windowed access.
"""
pmts = emb_module.split_embedding_weights()
for table_config, pmt in zip(config.embedding_tables, pmts):
table_name = table_config.name
emb_table = pmt
weight: nn.Parameter = nn.Parameter(emb_table)
# pyre-ignore
weight._in_backward_optimizers = [EmptyFusedOptimizer()]
yield (table_name, weight)


def _gen_named_parameters_by_table_fused(
emb_module: SplitTableBatchedEmbeddingBagsCodegen,
table_name_to_count: Dict[str, int],
@@ -1257,7 +1281,7 @@ def __init__(
pg,
)
self._param_per_table: Dict[str, nn.Parameter] = dict(
_gen_named_parameters_by_table_ssd(
_gen_named_parameters_by_table_ssd_pmt(
emb_module=self._emb_module,
table_name_to_count=self.table_name_to_count.copy(),
config=self._config,
@@ -1291,11 +1315,31 @@ def state_dict(
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
no_snapshot: bool = True,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()

return destination
"""
Args:
no_snapshot (bool): the tensors in the returned dict are
PartiallyMaterializedTensors. this argument controls wether the
PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
PartiallyMaterializedTensor has a RocksDB snapshot handle
"""
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
ret = get_state_dict(
emb_table_config_copy,
emb_tables,
self._pg,
destination,
prefix,
)
return ret

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -1308,14 +1352,16 @@ def named_parameters(
):
# hack before we support optimizer on sharded parameter level
# can delete after PEA deprecation
# pyre-ignore [6]
param = nn.Parameter(tensor)
# pyre-ignore
param._in_backward_optimizers = [EmptyFusedOptimizer()]
yield name, param

# pyre-ignore [15]
def named_split_embedding_weights(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
assert (
remove_duplicate
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -1326,6 +1372,21 @@ def named_split_embedding_weights(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(
self, prefix: str = ""
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor

def flush(self) -> None:
"""
Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -1340,11 +1401,11 @@ def purge(self) -> None:
self.emb_module.lxu_cache_weights.zero_()
self.emb_module.lxu_cache_state.fill_(-1)

def split_embedding_weights(self) -> List[torch.Tensor]:
"""
Return fake tensors.
"""
return [param.data for param in self._param_per_table.values()]
# pyre-ignore [15]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[PartiallyMaterializedTensor]:
return self.emb_module.split_embedding_weights(no_snapshot)


class BatchedFusedEmbeddingBag(
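As a usage note on the kernel-level changes above, a short hedged sketch: `kv_emb` stands in for a KeyValueEmbeddingBag instance and the helper name is hypothetical, while the method names and the no_snapshot semantics come from this diff.

def read_kv_tables_for_checkpoint(kv_emb) -> None:
    # Sketch only. Default access (no_snapshot=True): PartiallyMaterializedTensors
    # that do NOT own a RocksDB snapshot handle, not meant for reading table contents.
    _ = kv_emb.split_embedding_weights()

    # Snapshot-backed access: each PartiallyMaterializedTensor owns a RocksDB
    # snapshot handle and supports windowed reads. The flush this requires is done
    # in ShardedEmbeddingBagCollection._pre_state_dict_hook().
    for name, pmt in kv_emb.get_named_split_embedding_weights_snapshot():
        # a checkpoint writer would read row windows out of `pmt` here
        print(name, type(pmt).__name__)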
8 changes: 7 additions & 1 deletion torchrec/distributed/embedding_kernel.py
@@ -14,6 +14,9 @@

import torch
import torch.distributed as dist
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
PartiallyMaterializedTensor,
)
from torch import nn
from torch.distributed._tensor import DTensor
from torchrec.distributed.embedding_types import (
@@ -60,6 +63,7 @@ def get_state_dict(
List[Union[nn.Module, torch.Tensor]],
List[torch.Tensor],
List[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]],
List[PartiallyMaterializedTensor],
],
pg: Optional[dist.ProcessGroup] = None,
destination: Optional[Dict[str, Any]] = None,
@@ -99,7 +103,9 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
qbias = param[2]
param = param[0]

assert embedding_table.local_rows == param.size(0) # pyre-ignore[16]
assert embedding_table.local_rows == param.size( # pyre-ignore[16]
0
), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16]

if qscale is not None:
assert embedding_table.local_cols == param.size(1) # pyre-ignore[16]
15 changes: 15 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -21,6 +21,9 @@
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
PartiallyMaterializedTensor,
)
from torch import nn

from torch.autograd.function import FunctionCtx
@@ -649,6 +652,18 @@ def named_parameters_by_table(
) in embedding_kernel.named_parameters_by_table():
yield (table_name, tbe_slice)

def get_named_split_embedding_weights_snapshot(
self,
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for emb_module in self._emb_modules:
if isinstance(emb_module, KeyValueEmbeddingBag):
yield from emb_module.get_named_split_embedding_weights_snapshot()

def flush(self) -> None:
for emb_module in self._emb_modules:
emb_module.flush()
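One detail worth calling out in the lookup-level iterator above: it yields bare table names (no `.weight` suffix), so callers that build state_dict keys append the suffix themselves, as the hook in embeddingbag.py below does. A hedged sketch with a hypothetical helper name:

def kv_snapshot_tables_by_fqn(lookup, prefix: str = "embedding_bags.") -> dict:
    # Sketch only: `lookup` stands in for a grouped lookup holding KeyValueEmbeddingBag
    # kernels; only get_named_split_embedding_weights_snapshot() comes from this diff.
    return {
        f"{prefix}{table_name}.weight": pmt
        for table_name, pmt in lookup.get_named_split_embedding_weights_snapshot()
    }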
50 changes: 43 additions & 7 deletions torchrec/distributed/embeddingbag.py
@@ -832,7 +832,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()

model_parallel_name_to_compute_kernel: Dict[str, str] = {}
self._model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
@@ -843,7 +843,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
[("local_tensors", []), ("local_offsets", [])]
)
model_parallel_name_to_compute_kernel[table_name] = (
self._model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

@@ -897,18 +897,17 @@ def _initialize_torch_state(self) -> None: # noqa
"weight", nn.Parameter(torch.empty(0))
)
if (
model_parallel_name_to_compute_kernel[table_name]
self._model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embedding_bags[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]
if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
continue

if self._output_dtensor:
assert self._model_parallel_name_to_compute_kernel[table_name] not in {
EmbeddingComputeKernel.KEY_VALUE.value
}
if shards_wrapper_map["local_tensors"]:
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
@@ -937,6 +936,8 @@ def _initialize_torch_state(self) -> None: # noqa
)
else:
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
@@ -945,6 +946,21 @@
)
)

def extract_sharded_kvtensors(
module: ShardedEmbeddingBagCollection,
) -> OrderedDict[str, ShardedTensor]:
# retrieve all kvstore backed tensors
ret = OrderedDict()
for (
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
if self._model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
ret[table_name] = sharded_t
return ret

def post_state_dict_hook(
module: ShardedEmbeddingBagCollection,
destination: Dict[str, torch.Tensor],
@@ -965,6 +981,26 @@ def post_state_dict_hook(
destination_key = f"{prefix}embedding_bags.{table_name}.weight"
destination[destination_key] = d_tensor

# kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
# snapshot for read access.
sharded_kvtensors = extract_sharded_kvtensors(module)
sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding in zip(module._lookups, module._embedding_shardings):
if isinstance(sharding, DpPooledEmbeddingSharding):
# unwrap DDP
lookup = lookup.module
else:
for key, v in lookup.get_named_split_embedding_weights_snapshot():
destination_key = f"{prefix}embedding_bags.{key}.weight"
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
table_name,
sharded_kvtensor,
) in sharded_kvtensors_copy.items():
destination_key = f"{prefix}embedding_bags.{table_name}.weight"
destination[destination_key] = sharded_kvtensor

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
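For readers skimming the hook above, a condensed restatement of its KEY_VALUE path; the helper name and simplified signature are hypothetical, and the deep-copy-then-patch pattern mirrors the code in this diff.

import copy
from collections import OrderedDict
from typing import Dict, Iterable

from torch.distributed._shard.sharded_tensor import ShardedTensor


def fill_kv_snapshots(
    kv_sharded_tensors: Dict[str, ShardedTensor], lookups: Iterable
) -> Dict[str, ShardedTensor]:
    # Sketch only. The ShardedTensors built in _initialize_torch_state wrap
    # snapshot-less PartiallyMaterializedTensors, so they are deep-copied and patched
    # with snapshot-backed tensors before being written into the state_dict destination.
    filled: Dict[str, ShardedTensor] = copy.deepcopy(OrderedDict(kv_sharded_tensors))
    for lookup in lookups:
        for table_name, pmt in lookup.get_named_split_embedding_weights_snapshot():
            # these kv-backed tables have a single local shard per rank
            filled[table_name].local_shards()[0].tensor = pmt
    return filled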
