Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for custom __init__ methods in dataclasses. #215

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,48 @@ def __call__(self, cls):
getattr(base, "__dataclass_params__").frozen and not self.frozen):
raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

if not self.init:
# Store the custom init. We will pass init=True to dataclass to create a
# default initializer, which is necessary for flattening/unflattening.
init_fn = cls.__init__
# Delete init method otherwise dataclass function will not create a
# default init method.
if hasattr(cls, "__init__"):
delattr(cls, "__init__")

# pytype: disable=wrong-keyword-args
dcls = dataclasses.dataclass(
cls,
init=self.init,
init=True,
repr=self.repr,
eq=self.eq,
order=self.order,
unsafe_hash=self.unsafe_hash,
frozen=self.frozen)
# pytype: enable=wrong-keyword-args

# Store the default init
dcls.__dataclass_init__ = dcls.__init__
if not self.init:
# Re-bind the custom init
dcls.__init__ = init_fn

fields_names = set(f.name for f in dataclasses.fields(dcls))
invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
if invalid_fields:
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({dcls}).")

if self.mappable_dataclass:
# Mappable dataclasses require a keyword-only init as the class __init__
# method
if not self.init:
raise ValueError(
"Mappable dataclasses are incompatible with non-default " +
"initializers (i.e. `init=False` and `mappable_dataclass=True` " +
"flags are incompatible)."
)

dcls = mappable_dataclass(dcls)
# We remove `collection.abc.Mapping` mixin methods here to allow
# fields with these names.
Expand All @@ -185,7 +209,10 @@ def __call__(self, cls):
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))
obj = dcls.__new__(dcls)
kwargs = dict(zip(dcls.__dataclass_fields__.keys(), args))
dcls.__dataclass_init__(obj, **kwargs)
return obj

def _to_tuple(self):
return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
Expand All @@ -205,11 +232,10 @@ def _setstate(self, state):
class_self.registered = True
self.__dict__.update(state)

orig_init = dcls.__init__

# Patch object's __init__ such that the class is registered on creation if
# it is not registered on deserialization.
@functools.wraps(orig_init)
orig_init = dcls.__init__
@functools.wraps(dcls.__init__)
def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
Expand Down Expand Up @@ -239,8 +265,14 @@ def register_dataclass_type_with_jax_tree_util(data_class):
constructable from keyword arguments corresponding to the members exposed
in instance.__dict__.
"""
flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1]
unflatten = lambda keys, values: data_class(**dict(zip(keys, values)))
def flatten(x):
return jax.util.unzip2(sorted(x.__dict__.items()))[::-1]

def unflatten(keys, values):
obj = data_class.__new__(data_class)
data_class.__dataclass_init__(obj, **dict(zip(keys, values)))
return obj

try:
jax.tree_util.register_pytree_node(
nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)
Expand Down
16 changes: 16 additions & 0 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ class FrozenDataclass():
b: pytypes.ArrayDevice


@chex_dataclass(init=False, mappable_dataclass=False)
class DataclassWithInit:
a: pytypes.ArrayDevice
b: pytypes.ArrayDevice

def __init__(self, n):
self.a = np.eye(n)
self.b = np.zeros((n, n))


def dummy_dataclass(factor=1., frozen=False):
class_ctor = FrozenDataclass if frozen else Dataclass
return class_ctor(
Expand Down Expand Up @@ -301,6 +311,12 @@ def testIsDataclass(self, test_type):

class DataclassesTest(parameterized.TestCase):

def test_custom_initializer(self):
obj = DataclassWithInit(2)
flattened, treedef = jax.tree_util.tree_flatten(obj)
unused_unflattened_obj = jax.tree_util.tree_unflatten(treedef, flattened)
# asserts.assert_trees_all_close(obj, unflattened_obj)

@parameterized.parameters([True, False])
def test_dataclass_tree_leaves(self, frozen):
obj = dummy_dataclass(frozen=frozen)
Expand Down