Skip to content

Commit

Permalink
Explicitly provide serialized_type_name in registering JaggedTensor p…
Browse files Browse the repository at this point in the history
…ytree (#1513)

Summary:
Pull Request resolved: #1513

In case JaggedTensor is moved to other directory and breaks IR compatibility.

Context: https://docs.google.com/document/d/1pLgEyTH-8VwpSXAwpJoK5xtoUYETepIt6zKOxg5rp9o/edit#bookmark=id.yrehblflecei

Reviewed By: angelayi

Differential Revision: D51312977

fbshipit-source-id: 3aa730a15c9d4e8970c2af5239296cf276f169d4
  • Loading branch information
wilson100hong authored and facebook-github-bot committed Nov 15, 2023
1 parent 9545c5a commit 5a4daed
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,12 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
return [getattr(t, a) for a in JaggedTensor._fields]


_register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
_register_pytree_node(
JaggedTensor,
_jt_flatten,
_jt_unflatten,
serialized_type_name=f"{JaggedTensor.__module__}.{JaggedTensor.__name__}",
)
register_pytree_flatten_spec(JaggedTensor, _jt_flatten_spec)


Expand Down

0 comments on commit 5a4daed

Please sign in to comment.