Skip to content

Commit

Permalink
Use @cache instead of per-class boolean when registering dataclass py…
Browse files Browse the repository at this point in the history
…trees.

PiperOrigin-RevId: 561315699
  • Loading branch information
tomhennigan authored and ChexDev committed Aug 30, 2023
1 parent a373d02 commit 94a26ee
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def __init__(
self.frozen = frozen
self.kw_only = kw_only
self.mappable_dataclass = mappable_dataclass
self.registered = False

def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""
Expand Down Expand Up @@ -214,13 +213,9 @@ def _replace(self, **kwargs):
def _getstate(self):
return self.__dict__

class_self = self

# Patch __setstate__ to register the object on deserialization.
def _setstate(self, state):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
class_self.registered = True
register_dataclass_type_with_jax_tree_util(dcls)
self.__dict__.update(state)

orig_init = dcls.__init__
Expand All @@ -229,9 +224,7 @@ def _setstate(self, state):
# it is not registered on deserialization.
@functools.wraps(orig_init)
def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
class_self.registered = True
register_dataclass_type_with_jax_tree_util(dcls)
return orig_init(self, *args, **kwargs)

setattr(dcls, "from_tuple", _from_tuple)
Expand Down Expand Up @@ -270,6 +263,7 @@ def _flatten_with_path(dcls):
return path, keys


@functools.cache
def register_dataclass_type_with_jax_tree_util(data_class):
"""Register an existing dataclass so JAX knows how to handle it.
Expand Down

0 comments on commit 94a26ee

Please sign in to comment.