From 94a26ee3e4686bfeaa2b6053e8d310eefe7f8e9e Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Wed, 30 Aug 2023 05:50:20 -0700 Subject: [PATCH] Use @cache instead of per-class boolean when registering dataclass pytrees. PiperOrigin-RevId: 561315699 --- chex/_src/dataclass.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index fd732c6..c262d40 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -159,7 +159,6 @@ def __init__( self.frozen = frozen self.kw_only = kw_only self.mappable_dataclass = mappable_dataclass - self.registered = False def __call__(self, cls): """Forwards class to dataclasses's wrapper and registers it with JAX.""" @@ -214,13 +213,9 @@ def _replace(self, **kwargs): def _getstate(self): return self.__dict__ - class_self = self - # Patch __setstate__ to register the object on deserialization. def _setstate(self, state): - if not class_self.registered: - register_dataclass_type_with_jax_tree_util(dcls) - class_self.registered = True + register_dataclass_type_with_jax_tree_util(dcls) self.__dict__.update(state) orig_init = dcls.__init__ @@ -229,9 +224,7 @@ def _setstate(self, state): # it is not registered on deserialization. @functools.wraps(orig_init) def _init(self, *args, **kwargs): - if not class_self.registered: - register_dataclass_type_with_jax_tree_util(dcls) - class_self.registered = True + register_dataclass_type_with_jax_tree_util(dcls) return orig_init(self, *args, **kwargs) setattr(dcls, "from_tuple", _from_tuple) @@ -270,6 +263,7 @@ def _flatten_with_path(dcls): return path, keys +@functools.cache def register_dataclass_type_with_jax_tree_util(data_class): """Register an existing dataclass so JAX knows how to handle it.