Skip to content

Commit

Permalink
Skip test with problematic jaxlib version
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 29, 2023
1 parent 557bf36 commit c31b4c2
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/test_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import jax.lax as lax
import jax.lib
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down Expand Up @@ -514,6 +515,10 @@ def loop(init_xs):
assert speed < 1


@pytest.mark.skipif(
jax.lib.__version__ == "0.4.16", # pyright: ignore
reason="jaxlib bug; see https://github.com/google/jax/pull/17724",
)
# This isn't testing any particular failure mode: just that things generally work.
def test_speed_grad_checkpointed_while(getkey):
mlp = eqx.nn.MLP(2, 1, 2, 2, key=getkey())
Expand Down

0 comments on commit c31b4c2

Please sign in to comment.