diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index 096acf8..fd732c6 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -264,6 +264,7 @@ def _flatten_with_path(dcls): path = [] keys = [] for k, v in sorted(dcls.__dict__.items()): + k = jax.tree_util.GetAttrKey(k) path.append((k, v)) keys.append(k) return path, keys diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index 5bf8266..938add5 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -348,25 +348,37 @@ def test_dataclass_tree_map(self, frozen): def test_tree_flatten_with_keys(self): obj = dummy_dataclass() keys_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(obj) - self.assertEqual([k for k, _ in keys_and_leaves], - [('a', 'c'), ('a', 'd'), ('b',)]) + self.assertEqual( + [k for k, _ in keys_and_leaves], + [ + (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c')), + (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d')), + (jax.tree_util.GetAttrKey('b'),), + ], + ) leaves = [l for _, l in keys_and_leaves] new_obj = treedef.unflatten(leaves) self.assertEqual(new_obj, obj) def test_tree_map_with_keys(self): obj = dummy_dataclass() + key_value_list, unused_treedef = jax.tree_util.tree_flatten_with_path(obj) + # Convert a list of key-value tuples to a dict. + flat_obj = dict(key_value_list) + def f(path, x): - value = obj - for key in path: - value = getattr(value, key) + value = flat_obj[path] np.testing.assert_allclose(value, x) return path out = jax.tree_util.tree_map_with_path(f, obj) - self.assertEqual(out.a.c, ('a', 'c')) - self.assertEqual(out.a.d, ('a', 'd')) - self.assertEqual(out.b, ('b',)) + self.assertEqual( + out.a.c, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c')) + ) + self.assertEqual( + out.a.d, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d')) + ) + self.assertEqual(out.b, (jax.tree_util.GetAttrKey('b'),)) def test_tree_map_with_keys_traversal_order(self): # pytype: disable=wrong-arg-types