diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py index 0560812bd..9f970cd6f 100644 --- a/torchrec/ir/schema.py +++ b/torchrec/ir/schema.py @@ -48,3 +48,9 @@ class PositionWeightedModuleMetadata: @dataclass class PositionWeightedModuleCollectionMetadata: max_feature_lengths: List[Tuple[str, int]] + + +@dataclass +class KTRegroupAsDictMetadata: + groups: List[List[str]] + keys: List[str] diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py index 1c1fb79d2..24982a887 100644 --- a/torchrec/ir/serializer.py +++ b/torchrec/ir/serializer.py @@ -8,15 +8,15 @@ # pyre-strict import json -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Type import torch - from torch import nn from torchrec.ir.schema import ( EBCMetadata, EmbeddingBagConfigMetadata, FPEBCMetadata, + KTRegroupAsDictMetadata, PositionWeightedModuleCollectionMetadata, PositionWeightedModuleMetadata, ) @@ -32,6 +32,7 @@ PositionWeightedModuleCollection, ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -125,6 +126,22 @@ def fpebc_meta_forward( ) +def kt_regroup_meta_forward( + op_module: KTRegroupAsDict, keyed_tensors: List[KeyedTensor] +) -> Dict[str, torch.Tensor]: + lengths_dict: Dict[str, int] = {} + batch_size = keyed_tensors[0].values().size(0) + for kt in keyed_tensors: + for key, length in zip(kt.keys(), kt.length_per_key()): + lengths_dict[key] = length + out_lengths: List[int] = [0] * len(op_module._groups) + for i, group in enumerate(op_module._groups): + out_lengths[i] = sum(lengths_dict[key] for key in group) + arg_list = [kt.values() for kt in keyed_tensors] + outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, out_lengths) + return dict(zip(op_module._keys, outputs)) + + class JsonSerializer(SerializerInterface): """ Serializer for torch.export IR using json. @@ -364,3 +381,42 @@ def deserialize_from_dict( JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = ( FPEBCJsonSerializer ) + + +class KTRegroupAsDictJsonSerializer(JsonSerializer): + _module_cls = KTRegroupAsDict + + @classmethod + def swap_meta_forward(cls, module: nn.Module) -> None: + assert isinstance(module, cls._module_cls) + # pyre-ignore + module.forward = kt_regroup_meta_forward.__get__(module, cls._module_cls) + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + metadata = KTRegroupAsDictMetadata( + keys=module._keys, + groups=module._groups, + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = KTRegroupAsDictMetadata(**metadata_dict) + return KTRegroupAsDict( + keys=metadata.keys, + groups=metadata.groups, + ) + + +JsonSerializer.module_to_serializer_cls["KTRegroupAsDict"] = ( + KTRegroupAsDictJsonSerializer +) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 746708355..c250c6f2d 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -30,6 +30,7 @@ PositionWeightedModuleCollection, ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -433,3 +434,90 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: self.assertEqual(len(deserialized_out), len(eager_out)) for x, y in zip(deserialized_out, eager_out): self.assertTrue(torch.allclose(x, y)) + + def test_regroup_as_dict_module(self) -> None: + class Model(nn.Module): + def __init__(self, ebc, fpebc, regroup): + super().__init__() + self.ebc = ebc + self.fpebc = fpebc + self.regroup = regroup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, torch.Tensor]: + kt1 = self.ebc(features) + kt2 = self.fpebc(features) + return self.regroup([kt1, kt2]) + + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1", "f2"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f3", "f4"], + ) + tb3_config = EmbeddingBagConfig( + name="t3", + embedding_dim=5, + num_embeddings=10, + feature_names=["f5"], + ) + + ebc = EmbeddingBagCollection( + tables=[tb1_config, tb3_config], + is_weighted=False, + ) + max_feature_lengths = {"f3": 100, "f4": 100} + fpebc = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb2_config], + is_weighted=True, + ), + PositionWeightedModuleCollection( + max_feature_lengths=max_feature_lengths, + ), + ) + regroup = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"]) + model = Model(ebc, fpebc, regroup) + + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3", "f4", "f5"], + values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]), + offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]), + ) + self.assertFalse(model.regroup._is_inited) + + # Serialize EBC + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + self.assertFalse(model.regroup._is_inited) + eager_out = model(id_list_features) + self.assertFalse(model.regroup._is_inited) + + # Run forward on ExportedProgram + ep_output = ep.module()(id_list_features) + for key in eager_out.keys(): + self.assertEqual(ep_output[key].shape, eager_out[key].shape) + # Deserialize EBC + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + self.assertFalse(deserialized_model.regroup._is_inited) + deserialized_out = deserialized_model(id_list_features) + self.assertTrue(deserialized_model.regroup._is_inited) + for key in eager_out.keys(): + self.assertEqual(deserialized_out[key].shape, eager_out[key].shape)