Skip to content

Commit

Permalink
Haiku: fix jaxpr_info_test for jax v0.4.24
Browse files Browse the repository at this point in the history
The implementation of jnp.sign changed in jax-ml/jax#19390

PiperOrigin-RevId: 599194751
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Jan 17, 2024
1 parent 6339353 commit 0f4a96e
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions haiku/_src/jaxpr_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,22 @@ def add(x, y):

a = jnp.zeros((12, 7))
mod = jaxpr_info.make_model_info(add)(a, a)
self.assertContentsEqual(
jaxpr_info.format_module(mod), """
if jax.__version_info__ < (0, 4, 24):
expected = """
add
sign
sign in f32[12,7], out f32[12,7]
cos in f32[12,7], out f32[12,7]
add in f32[12,7], f32[12,7], out f32[12,7]
""")
"""
else:
expected = """
add
sign in f32[12,7], out f32[12,7]
cos in f32[12,7], out f32[12,7]
add in f32[12,7], f32[12,7], out f32[12,7]
"""
self.assertContentsEqual(jaxpr_info.format_module(mod), expected)

def test_compute_flops(self):

Expand All @@ -65,14 +73,23 @@ def add(x, y):

a = jnp.zeros((12, 7))
mod = jaxpr_info.make_model_info(add, compute_flops=_compute_flops)(a, a)
self.assertContentsEqual(
jaxpr_info.format_module(mod), """
# jnp.sign implementation changed in jax v0.4.24
if jax.__version_info__ < (0, 4, 24):
expected = """
add 252 flops
sign 84 flops
sign 84 flops in f32[12,7], out f32[12,7]
cos 84 flops in f32[12,7], out f32[12,7]
add 84 flops in f32[12,7], f32[12,7], out f32[12,7]
""")
"""
else:
expected = """
add 252 flops
sign 84 flops in f32[12,7], out f32[12,7]
cos 84 flops in f32[12,7], out f32[12,7]
add 84 flops in f32[12,7], f32[12,7], out f32[12,7]
"""
self.assertContentsEqual(jaxpr_info.format_module(mod), expected)

def test_haiku_module(self):

Expand Down

0 comments on commit 0f4a96e

Please sign in to comment.