diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py
index 71657769..2d923389 100644
--- a/chex/_src/dataclass.py
+++ b/chex/_src/dataclass.py
@@ -159,10 +159,19 @@ def __call__(self, cls):
           getattr(base, "__dataclass_params__").frozen and not self.frozen):
         raise TypeError("cannot inherit non-frozen dataclass from a frozen one")
 
+    if not self.init:
+      # Store the custom init. We will pass init=True to dataclass to create a
+      # default initializer, which is necessary for flattening/unflattening.
+      init_fn = cls.__init__
+      # Only an `__init__` defined on `cls` itself stops dataclass from making
+      # a default one; deleting an inherited `__init__` raises AttributeError.
+      if "__init__" in cls.__dict__:
+        delattr(cls, "__init__")
+
     # pytype: disable=wrong-keyword-args
     dcls = dataclasses.dataclass(
         cls,
-        init=self.init,
+        init=True,
         repr=self.repr,
         eq=self.eq,
         order=self.order,
@@ -170,6 +179,12 @@ def __call__(self, cls):
         frozen=self.frozen)
     # pytype: enable=wrong-keyword-args
 
+    # Store the default init
+    dcls.__dataclass_init__ = dcls.__init__
+    if not self.init:
+      # Re-bind the custom init
+      dcls.__init__ = init_fn
+
     fields_names = set(f.name for f in dataclasses.fields(dcls))
     invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
     if invalid_fields:
@@ -177,6 +192,15 @@ def __call__(self, cls):
                        f"{invalid_fields} ({dcls}).")
 
     if self.mappable_dataclass:
+      # Mappable dataclasses require a keyword-only init as the class __init__
+      # method
+      if not self.init:
+        raise ValueError(
+            "Mappable dataclasses are incompatible with non-default " +
+            "initializers (i.e. `init=False` and `mappable_dataclass=True` " +
+            "flags are incompatible)."
+        )
+
       dcls = mappable_dataclass(dcls)
       # We remove `collection.abc.Mapping` mixin methods here to allow
       # fields with these names.
@@ -185,7 +209,10 @@ def __call__(self, cls):
         delattr(dcls, attr)  # delete
 
     def _from_tuple(args):
-      return dcls(zip(dcls.__dataclass_fields__.keys(), args))
+      obj = dcls.__new__(dcls)
+      kwargs = dict(zip(dcls.__dataclass_fields__.keys(), args))
+      dcls.__dataclass_init__(obj, **kwargs)
+      return obj
 
     def _to_tuple(self):
       return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
@@ -205,11 +232,10 @@ def _setstate(self, state):
         class_self.registered = True
       self.__dict__.update(state)
 
-    orig_init = dcls.__init__
-
     # Patch object's __init__ such that the class is registered on creation if
     # it is not registered on deserialization.
-    @functools.wraps(orig_init)
+    orig_init = dcls.__init__
+    @functools.wraps(dcls.__init__)
     def _init(self, *args, **kwargs):
       if not class_self.registered:
         register_dataclass_type_with_jax_tree_util(dcls)
@@ -239,8 +265,17 @@ def register_dataclass_type_with_jax_tree_util(data_class):
       constructable from keyword arguments corresponding to the members exposed
       in instance.__dict__.
   """
-  flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1]
-  unflatten = lambda keys, values: data_class(**dict(zip(keys, values)))
+  def flatten(x):
+    return jax.util.unzip2(sorted(x.__dict__.items()))[::-1]
+
+  def unflatten(keys, values):
+    # Plain `dataclasses.dataclass` types carry no `__dataclass_init__`; fall
+    # back to their regular `__init__` so they can still be registered.
+    init_fn = getattr(data_class, "__dataclass_init__", data_class.__init__)
+    obj = data_class.__new__(data_class)
+    init_fn(obj, **dict(zip(keys, values)))
+    return obj
+
   try:
     jax.tree_util.register_pytree_node(
         nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)
diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py
index a8fb63de..4524598e 100644
--- a/chex/_src/dataclass_test.py
+++ b/chex/_src/dataclass_test.py
@@ -50,6 +50,16 @@ class FrozenDataclass():
   b: pytypes.ArrayDevice
 
 
+@chex_dataclass(init=False, mappable_dataclass=False)
+class DataclassWithInit:
+  a: pytypes.ArrayDevice
+  b: pytypes.ArrayDevice
+
+  def __init__(self, n):
+    self.a = np.eye(n)
+    self.b = np.zeros((n, n))
+
+
 def dummy_dataclass(factor=1., frozen=False):
   class_ctor = FrozenDataclass if frozen else Dataclass
   return class_ctor(
@@ -301,6 +311,14 @@ def testIsDataclass(self, test_type):
 
 class DataclassesTest(parameterized.TestCase):
 
+  def test_custom_initializer(self):
+    obj = DataclassWithInit(2)
+    flattened, treedef = jax.tree_util.tree_flatten(obj)
+    unflattened_obj = jax.tree_util.tree_unflatten(treedef, flattened)
+    self.assertIsInstance(unflattened_obj, DataclassWithInit)
+    np.testing.assert_array_equal(unflattened_obj.a, obj.a)
+    np.testing.assert_array_equal(unflattened_obj.b, obj.b)
+
   @parameterized.parameters([True, False])
   def test_dataclass_tree_leaves(self, frozen):
     obj = dummy_dataclass(frozen=frozen)