Skip to content

Commit

Permalink
Relax error message in statful_tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649072487
  • Loading branch information
Cristian Garcia authored and copybara-github committed Jul 3, 2024
1 parent a7b7e73 commit 66d8b74
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,7 @@ def creates_params(_):
def test_vmap_split_rng_out_axes_error_no_split_rng(self):
f = stateful.vmap(lambda x: x, split_rng=False, out_axes=None)
x = jnp.arange(4)
with self.assertRaisesRegex(ValueError,
"vmap has mapped output but out_axes is None"):
with self.assertRaisesRegex(ValueError, ".*vmap.*out_axes.*None.*"):
# test our split_rng error does not clobber jax error message.
f(x)

Expand All @@ -696,8 +695,7 @@ def g(x):
f(x)

x = jnp.arange(4)
with self.assertRaisesRegex(ValueError,
"vmap has mapped output but out_axes is None"):
with self.assertRaisesRegex(ValueError, ".*vmap.*out_axes.*None.*"):
# test our split_rng error does not clobber jax error message.
g.apply({}, jax.random.PRNGKey(42), x)

Expand Down

0 comments on commit 66d8b74

Please sign in to comment.