Skip to content

Commit

Permalink
fix test not to depend on details of JAX shape error type/message
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583136866
  • Loading branch information
mattjj authored and copybara-github committed Nov 16, 2023
1 parent ce05309 commit 1f79d80
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion haiku/_src/nets/vqvae_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def testConstruct(self, constructor, kwargs):
def testShapeChecking(self, constructor, kwargs):
vqvae_module = constructor(**kwargs)
wrong_shape_input = np.random.randn(100, kwargs['embedding_dim'] * 2)
with self.assertRaisesRegex(TypeError, 'total size must be unchanged'):
with self.assertRaises((ValueError, TypeError)):
vqvae_module(
jnp.array(wrong_shape_input.astype(np.float32)), is_training=False)

Expand Down

0 comments on commit 1f79d80

Please sign in to comment.