Skip to content

Commit

Permalink
Add IR serializer for KTRegroupAsDict Module
Browse files Browse the repository at this point in the history
Summary:
# context
* previously `KTRegroupAsDict` can't really supported by torch.export (IR) because this module has an intialization step as running the first batch.
* during the export the `KTRegroupAsDict` module will be initialized by a fake_tensor which is wrong
* if we initialize the module before torch.export, the device would be an issue.
* another issue is that current torch.export [can't support conditional logic in training](https://pytorch.org/docs/stable/cond.html), where initialization step only runs once.
> torch.cond is a prototype feature in PyTorch. It has limited support for input and output types and doesn’t support training currently. Please look forward to a more stable implementation in a future version of PyTorch.

NOTE: this is more like a workaround solution

# details
* we treat the `KTRegroupAsDict` as another sparse_arch and do the model swap before and after torch.export.
* more context: D59019375

Differential Revision: D57578012
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 1, 2024
1 parent 9eb6b89 commit acd3b39
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 2 deletions.
6 changes: 6 additions & 0 deletions torchrec/ir/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ class PositionWeightedModuleMetadata:
@dataclass
class PositionWeightedModuleCollectionMetadata:
max_feature_lengths: List[Tuple[str, int]]


@dataclass
class KTRegroupAsDictMetadata:
groups: List[List[str]]
keys: List[str]
60 changes: 58 additions & 2 deletions torchrec/ir/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
# pyre-strict

import json
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Type

import torch

from torch import nn
from torchrec.ir.schema import (
EBCMetadata,
EmbeddingBagConfigMetadata,
FPEBCMetadata,
KTRegroupAsDictMetadata,
PositionWeightedModuleCollectionMetadata,
PositionWeightedModuleMetadata,
)
Expand All @@ -32,6 +32,7 @@
PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


Expand Down Expand Up @@ -125,6 +126,22 @@ def fpebc_meta_forward(
)


def kt_regroup_meta_forward(
op_module: KTRegroupAsDict, keyed_tensors: List[KeyedTensor]
) -> Dict[str, torch.Tensor]:
lengths_dict: Dict[str, int] = {}
batch_size = keyed_tensors[0].values().size(0)
for kt in keyed_tensors:
for key, length in zip(kt.keys(), kt.length_per_key()):
lengths_dict[key] = length
out_lengths: List[int] = [0] * len(op_module._groups)
for i, group in enumerate(op_module._groups):
out_lengths[i] = sum(lengths_dict[key] for key in group)
arg_list = [kt.values() for kt in keyed_tensors]
outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, out_lengths)
return dict(zip(op_module._keys, outputs))


class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using json.
Expand Down Expand Up @@ -364,3 +381,42 @@ def deserialize_from_dict(
JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = (
FPEBCJsonSerializer
)


class KTRegroupAsDictJsonSerializer(JsonSerializer):
_module_cls = KTRegroupAsDict

@classmethod
def swap_meta_forward(cls, module: nn.Module) -> None:
assert isinstance(module, cls._module_cls)
# pyre-ignore
module.forward = kt_regroup_meta_forward.__get__(module, cls._module_cls)

@classmethod
def serialize_to_dict(
cls,
module: nn.Module,
) -> Dict[str, Any]:
metadata = KTRegroupAsDictMetadata(
keys=module._keys,
groups=module._groups,
)
return metadata.__dict__

@classmethod
def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten_ep: Optional[nn.Module] = None,
) -> nn.Module:
metadata = KTRegroupAsDictMetadata(**metadata_dict)
return KTRegroupAsDict(
keys=metadata.keys,
groups=metadata.groups,
)


JsonSerializer.module_to_serializer_cls["KTRegroupAsDict"] = (
KTRegroupAsDictJsonSerializer
)
88 changes: 88 additions & 0 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


Expand Down Expand Up @@ -433,3 +434,90 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
self.assertEqual(len(deserialized_out), len(eager_out))
for x, y in zip(deserialized_out, eager_out):
self.assertTrue(torch.allclose(x, y))

def test_regroup_as_dict_module(self) -> None:
class Model(nn.Module):
def __init__(self, ebc, fpebc, regroup):
super().__init__()
self.ebc = ebc
self.fpebc = fpebc
self.regroup = regroup

def forward(
self,
features: KeyedJaggedTensor,
) -> Dict[str, torch.Tensor]:
kt1 = self.ebc(features)
kt2 = self.fpebc(features)
return self.regroup([kt1, kt2])

tb1_config = EmbeddingBagConfig(
name="t1",
embedding_dim=3,
num_embeddings=10,
feature_names=["f1", "f2"],
)
tb2_config = EmbeddingBagConfig(
name="t2",
embedding_dim=4,
num_embeddings=10,
feature_names=["f3", "f4"],
)
tb3_config = EmbeddingBagConfig(
name="t3",
embedding_dim=5,
num_embeddings=10,
feature_names=["f5"],
)

ebc = EmbeddingBagCollection(
tables=[tb1_config, tb3_config],
is_weighted=False,
)
max_feature_lengths = {"f3": 100, "f4": 100}
fpebc = FeatureProcessedEmbeddingBagCollection(
EmbeddingBagCollection(
tables=[tb2_config],
is_weighted=True,
),
PositionWeightedModuleCollection(
max_feature_lengths=max_feature_lengths,
),
)
regroup = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"])
model = Model(ebc, fpebc, regroup)

id_list_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3", "f4", "f5"],
values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
)
self.assertFalse(model.regroup._is_inited)

# Serialize EBC
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
model,
(id_list_features,),
{},
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
)

self.assertFalse(model.regroup._is_inited)
eager_out = model(id_list_features)
self.assertFalse(model.regroup._is_inited)

# Run forward on ExportedProgram
ep_output = ep.module()(id_list_features)
for key in eager_out.keys():
self.assertEqual(ep_output[key].shape, eager_out[key].shape)
# Deserialize EBC
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
self.assertFalse(deserialized_model.regroup._is_inited)
deserialized_out = deserialized_model(id_list_features)
self.assertTrue(deserialized_model.regroup._is_inited)
for key in eager_out.keys():
self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)

0 comments on commit acd3b39

Please sign in to comment.