Allow for nested chex.chexify #306

Open
nicow-elia opened this issue Sep 13, 2023 · 1 comment

Comments

@nicow-elia

Hello, I have a dilemma with chexify - consider the following code:

import chex
import jax
import jax.numpy as jnp
import pytest

# If the @chex.chexify below is left in place, the second test fails.
# If it is commented out, the first test fails.
@chex.chexify
@jax.jit
def log_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)

@chex.chexify
@jax.jit
def combo_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)


def test_log_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()

def test_combo_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()

    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()

If I comment out the first chexify, test_log_safe fails with "RuntimeError: Value assertions can only be called from functions wrapped with @chex.chexify. See the docs.", which makes sense to me. However, once I add the decorator back, test_combo_safe fails with "RuntimeError: Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level."

A hack in this simple scenario would be to keep two versions of the function: a log_safe without the chexify decorator, and a log_safe_test = chex.chexify(log_safe) that is only called from my tests. However, that solution is pretty clumsy, especially if I have a lot of these scenarios. In a codebase that is JAX end to end, every function except the outermost one would need this hack. Would it be possible to allow nested chex.chexify, where every application after the outermost one simply does nothing or just raises a warning?
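To make the requested behaviour concrete, here is a minimal pure-Python sketch of a re-entrant wrapper: the outermost call enables checks, and nested wrapped calls just pass through. This is not chex code and says nothing about how chexify is implemented; the names reentrant_checks, check and CheckFailed are hypothetical stand-ins, and the check raises immediately instead of being collected asynchronously.

```python
import functools
import math
import threading

# Per-thread flag recording whether a wrapped call is currently running.
_state = threading.local()


class CheckFailed(Exception):
    """Hypothetical error type for a failed value check."""


def check(condition, message):
    """Stand-in for a value assertion; only legal inside a wrapped call."""
    if not getattr(_state, "active", False):
        raise RuntimeError("check() must be called from a wrapped function")
    if not condition:
        raise CheckFailed(message)


def reentrant_checks(fn):
    """Hypothetical decorator: the outermost wrapper enables checks, nested ones do nothing."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if getattr(_state, "active", False):
            # Already inside a wrapped call: behave as if undecorated.
            return fn(*args, **kwargs)
        _state.active = True
        try:
            return fn(*args, **kwargs)
        finally:
            # Always reset, so a failed check does not leak the flag.
            _state.active = False
    return wrapper


@reentrant_checks
def log_safe(xs):
    check(all(v > 0 for v in xs), "log_safe needs positive inputs")
    return [math.log(v) for v in xs]


@reentrant_checks
def combo_safe(xs):
    check(all(v != 1 for v in xs), "combo_safe needs inputs different from 1")
    # Nested call into another wrapped function; its wrapper is a no-op here.
    return [l / (v - 1) for l, v in zip(log_safe(xs), xs)]
```

With this shape, both log_safe and combo_safe can be used as entry points, so each can be tested directly without keeping a second, undecorated copy.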

@Edgeworth

I'm also finding the recent-ish change that disallows nested chexify quite difficult to work with. My graph of functions is not a tree with a single root, but I still want chexify in this situation: f() calls g() calls h(), and g() calls h().
