Skip to content

Commit

Permalink
Improve usability of chexify.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578145061
  • Loading branch information
hbq1 authored and ChexDev committed Oct 31, 2023
1 parent 3c9af1b commit 493124d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
10 changes: 10 additions & 0 deletions chex/_src/asserts_chexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,16 @@ def _chexified_fn(*args, **kwargs):
'Nested @chexify wrapping is disallowed. '
'Make sure that you only wrap the function at the outermost level.')

if _ai.has_tracers((args, kwargs)):
raise RuntimeError(
'@chexify must be applied on top of all (p)jit/pmap transformations'
' (otherwise it will result in `UnexpectedTracerError`). If you have'
' functions that use value assertions, do not wrap them'
' individually -- just wrap the outermost function after'
' applying all your JAX transformations. See the example at '
'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions'
)

if async_check:
# Check completed calls.
while async_check_futures and async_check_futures[0].done():
Expand Down
30 changes: 30 additions & 0 deletions chex/_src/asserts_chexify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,36 @@ def fn(x, y):
chexified_fn(jnp.array([2]))
chexified_fn.wait_checks() # Fail: not equal

def test_wrong_order_of_wrapping(self):

@chexify_async
def inner_fn(x, y):
asserts.assert_trees_all_equal(x, y)
return jax.tree_map(jnp.add, x, y)

def outer_fn(x, y):
z = inner_fn(x, y)
return jax.tree_map(jnp.square, z)

x = jnp.array([1])
y = jnp.array([1])
with self.assertRaisesRegex(
RuntimeError, '@chexify must be applied on top of all'
):
jax.jit(inner_fn)(x, y)

with self.assertRaisesRegex(
RuntimeError, '@chexify must be applied on top of all'
):
jax.jit(outer_fn)(x, y)

with self.assertRaisesRegex(
RuntimeError, '@chexify must be applied on top of all'
):
jax.jit(chexify_async(outer_fn))(x, y)

outer_fn(x, y)


class AssertsChexifyTestSuite(variants.TestCase):
"""Test suite for chexify assertions."""
Expand Down

0 comments on commit 493124d

Please sign in to comment.