Hello, I have a dilemma with chexify - consider the following code:
import chex
import jax
import jax.numpy as jnp
import pytest


# If this is not commented out, the second test will fail
# If this is commented out, the first test will fail
@chex.chexify
@jax.jit
def log_safe(x: jnp.ndarray) -> jnp.ndarray:
    # Value assertion: every element must be strictly positive.
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)


@chex.chexify
@jax.jit
def combo_safe(x: jnp.ndarray) -> jnp.ndarray:
    # Value assertion: no element may equal 1 (avoids division by zero).
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)


def test_log_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()


def test_combo_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()
    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()
If I comment out the first chexify, test_log_safe fails with "RuntimeError: Value assertions can only be called from functions wrapped with @chex.chexify. See the docs.", which makes sense to me. However, once I add the decorator back in, test_combo_safe fails with "RuntimeError: Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level."
A hack in this simple scenario would be to make two versions of the function: a log_safe without the chexify decorator, and a log_safe_test = chex.chexify(log_safe) that I call only from my tests. However, that solution is pretty clumsy, especially if I have a lot of these scenarios. In a codebase that is JAX end to end, every function except the outermost one would need this hack. Would it be possible to allow nested chex.chexify, where every application after the outermost one either does nothing or just raises a warning?
I'm also finding the recent-ish change to disallow multiple chexify wrappers quite difficult to work with: my graph of functions is not a tree with a single root, but I still want chexify in this situation. For example: f() calls g() calls h(), and g() calls h()