From a11ec435fef5652cfeed52eb6fd29ae62e705f7d Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Wed, 26 Jul 2023 10:56:28 -0700 Subject: [PATCH] Don't swallow errors with RNG keys in parameter initialisers. Given the following code snippet: ```python def f(x): x = hk.Conv2D(3, 1)(x) return x f = hk.transform(f) x = jnp.ones([1, 28, 28]) f.init(None, x) ``` At HEAD we fail with a confusing error: ValueError: Parameters cannot be `None`. After this change we fail with a far more useful: MissingRNGError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers. PiperOrigin-RevId: 551256122 --- haiku/_src/base.py | 2 ++ haiku/_src/base_test.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/haiku/_src/base.py b/haiku/_src/base.py index cc492b198..1c2a2ed05 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -683,6 +683,8 @@ def get_parameter( except MissingRNGError as e: if frame.params_frozen: raise param_missing_error from e + else: + raise e if param is DO_NOT_STORE: # Initializers or custom creators that return `DO_NOT_STORE` are required diff --git a/haiku/_src/base_test.py b/haiku/_src/base_test.py index 082ee6563..8d69fdaee 100644 --- a/haiku/_src/base_test.py +++ b/haiku/_src/base_test.py @@ -137,6 +137,15 @@ def test_parameter_in_immutable_ctx(self, params): ValueError, "parameters must be created as part of `init`"): base.get_parameter("w", [], init=jnp.zeros) + def test_get_parameter_rng_exception(self): + with base.new_context(): + with self.assertRaisesRegex( + base.MissingRNGError, "pass a non-None PRNGKey to init" + ): + base.get_parameter( + "w", [], init=lambda shape, dtype: base.next_rng_key() + ) + def test_get_parameter_wrong_shape(self): with base.new_context(): with self.assertRaisesRegex(ValueError, "does not match shape"):