From 1f79d80de5698c57db384bbb1b52fdd60c603c8a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 16 Nov 2023 12:59:49 -0800 Subject: [PATCH] fix test not to depend on details of JAX shape error type/message PiperOrigin-RevId: 583136866 --- haiku/_src/nets/vqvae_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haiku/_src/nets/vqvae_test.py b/haiku/_src/nets/vqvae_test.py index b15f567d2..d24b121f1 100644 --- a/haiku/_src/nets/vqvae_test.py +++ b/haiku/_src/nets/vqvae_test.py @@ -91,7 +91,7 @@ def testConstruct(self, constructor, kwargs): def testShapeChecking(self, constructor, kwargs): vqvae_module = constructor(**kwargs) wrong_shape_input = np.random.randn(100, kwargs['embedding_dim'] * 2) - with self.assertRaisesRegex(TypeError, 'total size must be unchanged'): + with self.assertRaises((ValueError, TypeError)): vqvae_module( jnp.array(wrong_shape_input.astype(np.float32)), is_training=False)