diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py index 8f500001..d3a21b20 100644 --- a/tests/test_while_loop.py +++ b/tests/test_while_loop.py @@ -4,6 +4,7 @@ import jax import jax.lax as lax +import jax.lib import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu @@ -514,6 +515,10 @@ def loop(init_xs): assert speed < 1 +@pytest.mark.skipif( + jax.lib.__version__ == "0.4.16", # pyright: ignore + reason="jaxlib bug; see https://github.com/google/jax/pull/17724", +) # This isn't testing any particular failure mode: just that things generally work. def test_speed_grad_checkpointed_while(getkey): mlp = eqx.nn.MLP(2, 1, 2, 2, key=getkey())