From 1c99559804fce387eda34b0678c1f5fa8a5a96dc Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 25 Oct 2024 10:41:47 -0700 Subject: [PATCH] Add handling for JAX typed PRNG keys in chex.assert_trees_all_equal PiperOrigin-RevId: 689839632 --- chex/_src/asserts.py | 11 ++++++++++- chex/_src/asserts_chexify_test.py | 15 +++++++++++++++ chex/_src/asserts_test.py | 9 +++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 718ce2e..1cdfc23 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1551,11 +1551,20 @@ def _assert_trees_all_equal_static( AssertionError: If the leaf values actual and desired are not exactly equal. """ def assert_fn(arr_1, arr_2): + if jax.dtypes.issubdtype( + getattr(arr_1, "dtype", None), jax.dtypes.prng_key + ) and jax.dtypes.issubdtype( + getattr(arr_2, "dtype", None), jax.dtypes.prng_key + ): + assert jax.random.key_impl(arr_1) == jax.random.key_impl(arr_2) + arr_1 = jax.random.key_data(arr_1) + arr_2 = jax.random.key_data(arr_2) np.testing.assert_array_equal( _ai.jnp_to_np_array(arr_1), _ai.jnp_to_np_array(arr_2), err_msg="Error in value equality check: Values not exactly equal", - strict=strict) + strict=strict, + ) def cmp_fn(arr_1, arr_2) -> bool: try: diff --git a/chex/_src/asserts_chexify_test.py b/chex/_src/asserts_chexify_test.py index 348c50b..b4226d4 100644 --- a/chex/_src/asserts_chexify_test.py +++ b/chex/_src/asserts_chexify_test.py @@ -615,6 +615,21 @@ def fn(x, y): ): chexified_fn(tree_1, tree_2) # Fail: not equal + def test_assert_trees_all_equal_with_prng_keys(self): + @jax.jit + def fn(x, y): + asserts.assert_trees_all_equal(x, y) + return x['a'] + y['a'] + + chexified_fn = asserts_chexify.chexify(fn, async_check=False) + tree1 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(1))} + tree2 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(2))} + chexified_fn(tree1, tree1) # OK + with self.assertRaisesRegex( + AssertionError, re.escape("Trees 0 and 1 differ in leaves '('key',)'") + ): + chexified_fn(tree1, tree2) # Fail: not equal + def test_assert_trees_all_close(self): @jax.jit def fn(x, y, z): diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index fed49ce..0d0bc22 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -923,6 +923,15 @@ def test_tree_all_finite_should_fail_inf(self): with self.assertRaisesRegex(ValueError, err_msg): asserts._assert_tree_all_finite_jittable(inf_tree) + def test_assert_trees_all_equal_prng_keys(self): + tree1 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(1))} + tree2 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(2))} + asserts.assert_trees_all_equal(tree1, tree1) # OK + + err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'key\'') + with self.assertRaisesRegex(AssertionError, err_regex): + asserts.assert_trees_all_equal(tree1, tree2) # Fail: not equal + def test_assert_trees_all_equal_passes_same_tree(self): tree = { 'a': [jnp.zeros((1,))],