From 5a4daed43b7d1a4d329b03a2407fb9d5a2c721c1 Mon Sep 17 00:00:00 2001 From: Wilson Hong Date: Tue, 14 Nov 2023 20:07:01 -0800 Subject: [PATCH] Explicitly provide serialized_type_name in registering JaggedTensor pytree (#1513) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1513 In case JaggedTensor is moved to other directory and breaks IR compatibility. Context: https://docs.google.com/document/d/1pLgEyTH-8VwpSXAwpJoK5xtoUYETepIt6zKOxg5rp9o/edit#bookmark=id.yrehblflecei Reviewed By: angelayi Differential Revision: D51312977 fbshipit-source-id: 3aa730a15c9d4e8970c2af5239296cf276f169d4 --- torchrec/sparse/jagged_tensor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 6e6afbe73..dbc8829c5 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -571,7 +571,12 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten return [getattr(t, a) for a in JaggedTensor._fields] -_register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten) +_register_pytree_node( + JaggedTensor, + _jt_flatten, + _jt_unflatten, + serialized_type_name=f"{JaggedTensor.__module__}.{JaggedTensor.__name__}", +) register_pytree_flatten_spec(JaggedTensor, _jt_flatten_spec)