diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index fd732c6..c262d40 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -159,7 +159,6 @@ def __init__( self.frozen = frozen self.kw_only = kw_only self.mappable_dataclass = mappable_dataclass - self.registered = False def __call__(self, cls): """Forwards class to dataclasses's wrapper and registers it with JAX.""" @@ -214,13 +213,9 @@ def _replace(self, **kwargs): def _getstate(self): return self.__dict__ - class_self = self - # Patch __setstate__ to register the object on deserialization. def _setstate(self, state): - if not class_self.registered: - register_dataclass_type_with_jax_tree_util(dcls) - class_self.registered = True + register_dataclass_type_with_jax_tree_util(dcls) self.__dict__.update(state) orig_init = dcls.__init__ @@ -229,9 +224,7 @@ def _setstate(self, state): # it is not registered on deserialization. @functools.wraps(orig_init) def _init(self, *args, **kwargs): - if not class_self.registered: - register_dataclass_type_with_jax_tree_util(dcls) - class_self.registered = True + register_dataclass_type_with_jax_tree_util(dcls) return orig_init(self, *args, **kwargs) setattr(dcls, "from_tuple", _from_tuple) @@ -270,6 +263,7 @@ def _flatten_with_path(dcls): return path, keys +@functools.cache def register_dataclass_type_with_jax_tree_util(data_class): """Register an existing dataclass so JAX knows how to handle it.