From 493124d4c2653cddd24d17b1a966289c6bef1e20 Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Tue, 31 Oct 2023 05:12:27 -0700 Subject: [PATCH] Improve usability of `chexify`. PiperOrigin-RevId: 578145061 --- chex/_src/asserts_chexify.py | 10 ++++++++++ chex/_src/asserts_chexify_test.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/chex/_src/asserts_chexify.py b/chex/_src/asserts_chexify.py index ff9565f5..88bed27c 100644 --- a/chex/_src/asserts_chexify.py +++ b/chex/_src/asserts_chexify.py @@ -189,6 +189,16 @@ def _chexified_fn(*args, **kwargs): 'Nested @chexify wrapping is disallowed. ' 'Make sure that you only wrap the function at the outermost level.') + if _ai.has_tracers((args, kwargs)): + raise RuntimeError( + '@chexify must be applied on top of all (p)jit/pmap transformations' + ' (otherwise it will result in `UnexpectedTracerError`). If you have' + ' functions that use value assertions, do not wrap them' + ' individually -- just wrap the outermost function after' + ' applying all your JAX transformations. See the example at ' + 'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions' + ) + if async_check: # Check completed calls. while async_check_futures and async_check_futures[0].done(): diff --git a/chex/_src/asserts_chexify_test.py b/chex/_src/asserts_chexify_test.py index 4b74b8a5..7cd455ae 100644 --- a/chex/_src/asserts_chexify_test.py +++ b/chex/_src/asserts_chexify_test.py @@ -288,6 +288,36 @@ def fn(x, y): chexified_fn(jnp.array([2])) chexified_fn.wait_checks() # Fail: not equal + def test_wrong_order_of_wrapping(self): + + @chexify_async + def inner_fn(x, y): + asserts.assert_trees_all_equal(x, y) + return jax.tree_map(jnp.add, x, y) + + def outer_fn(x, y): + z = inner_fn(x, y) + return jax.tree_map(jnp.square, z) + + x = jnp.array([1]) + y = jnp.array([1]) + with self.assertRaisesRegex( + RuntimeError, '@chexify must be applied on top of all' + ): + jax.jit(inner_fn)(x, y) + + with self.assertRaisesRegex( + RuntimeError, '@chexify must be applied on top of all' + ): + jax.jit(outer_fn)(x, y) + + with self.assertRaisesRegex( + RuntimeError, '@chexify must be applied on top of all' + ): + jax.jit(chexify_async(outer_fn))(x, y) + + outer_fn(x, y) + class AssertsChexifyTestSuite(variants.TestCase): """Test suite for chexify assertions."""