diff --git a/haiku/_src/jaxpr_info_test.py b/haiku/_src/jaxpr_info_test.py index f2ef3b35e..4eee01b1b 100644 --- a/haiku/_src/jaxpr_info_test.py +++ b/haiku/_src/jaxpr_info_test.py @@ -44,14 +44,22 @@ def add(x, y): a = jnp.zeros((12, 7)) mod = jaxpr_info.make_model_info(add)(a, a) - self.assertContentsEqual( - jaxpr_info.format_module(mod), """ + if jax.__version_info__ < (0, 4, 24): + expected = """ add sign sign in f32[12,7], out f32[12,7] cos in f32[12,7], out f32[12,7] add in f32[12,7], f32[12,7], out f32[12,7] -""") +""" + else: + expected = """ +add + sign in f32[12,7], out f32[12,7] + cos in f32[12,7], out f32[12,7] + add in f32[12,7], f32[12,7], out f32[12,7] +""" + self.assertContentsEqual(jaxpr_info.format_module(mod), expected) def test_compute_flops(self): @@ -65,14 +73,23 @@ def add(x, y): a = jnp.zeros((12, 7)) mod = jaxpr_info.make_model_info(add, compute_flops=_compute_flops)(a, a) - self.assertContentsEqual( - jaxpr_info.format_module(mod), """ + # jnp.sign implementation changed in jax v0.4.24 + if jax.__version_info__ < (0, 4, 24): + expected = """ add 252 flops sign 84 flops sign 84 flops in f32[12,7], out f32[12,7] cos 84 flops in f32[12,7], out f32[12,7] add 84 flops in f32[12,7], f32[12,7], out f32[12,7] -""") +""" + else: + expected = """ +add 252 flops + sign 84 flops in f32[12,7], out f32[12,7] + cos 84 flops in f32[12,7], out f32[12,7] + add 84 flops in f32[12,7], f32[12,7], out f32[12,7] +""" + self.assertContentsEqual(jaxpr_info.format_module(mod), expected) def test_haiku_module(self):