Skip to content

Commit

Permalink
Minor improvement to dataclass flatten_with_path.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557824539
  • Loading branch information
hamzamerzic authored and ChexDev committed Aug 22, 2023
1 parent 21890fd commit d5b4631
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
1 change: 1 addition & 0 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def _flatten_with_path(dcls):
path = []
keys = []
for k, v in sorted(dcls.__dict__.items()):
k = jax.tree_util.GetAttrKey(k)
path.append((k, v))
keys.append(k)
return path, keys
Expand Down
28 changes: 20 additions & 8 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,25 +348,37 @@ def test_dataclass_tree_map(self, frozen):
def test_tree_flatten_with_keys(self):
obj = dummy_dataclass()
keys_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(obj)
self.assertEqual([k for k, _ in keys_and_leaves],
[('a', 'c'), ('a', 'd'), ('b',)])
self.assertEqual(
[k for k, _ in keys_and_leaves],
[
(jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c')),
(jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d')),
(jax.tree_util.GetAttrKey('b'),),
],
)
leaves = [l for _, l in keys_and_leaves]
new_obj = treedef.unflatten(leaves)
self.assertEqual(new_obj, obj)

def test_tree_map_with_keys(self):
obj = dummy_dataclass()
key_value_list, unused_treedef = jax.tree_util.tree_flatten_with_path(obj)
# Convert a list of key-value tuples to a dict.
flat_obj = dict(key_value_list)

def f(path, x):
value = obj
for key in path:
value = getattr(value, key)
value = flat_obj[path]
np.testing.assert_allclose(value, x)
return path

out = jax.tree_util.tree_map_with_path(f, obj)
self.assertEqual(out.a.c, ('a', 'c'))
self.assertEqual(out.a.d, ('a', 'd'))
self.assertEqual(out.b, ('b',))
self.assertEqual(
out.a.c, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c'))
)
self.assertEqual(
out.a.d, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d'))
)
self.assertEqual(out.b, (jax.tree_util.GetAttrKey('b'),))

def test_tree_map_with_keys_traversal_order(self):
# pytype: disable=wrong-arg-types
Expand Down

0 comments on commit d5b4631

Please sign in to comment.