diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 6e6afbe73..dbc8829c5 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -571,7 +571,12 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten return [getattr(t, a) for a in JaggedTensor._fields] -_register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten) +_register_pytree_node( + JaggedTensor, + _jt_flatten, + _jt_unflatten, + serialized_type_name=f"{JaggedTensor.__module__}.{JaggedTensor.__name__}", +) register_pytree_flatten_spec(JaggedTensor, _jt_flatten_spec)