diff --git a/haiku/_src/base.py b/haiku/_src/base.py index cc492b198..1c2a2ed05 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -683,6 +683,8 @@ def get_parameter( except MissingRNGError as e: if frame.params_frozen: raise param_missing_error from e + else: + raise e if param is DO_NOT_STORE: # Initializers or custom creators that return `DO_NOT_STORE` are required diff --git a/haiku/_src/base_test.py b/haiku/_src/base_test.py index 082ee6563..8d69fdaee 100644 --- a/haiku/_src/base_test.py +++ b/haiku/_src/base_test.py @@ -137,6 +137,15 @@ def test_parameter_in_immutable_ctx(self, params): ValueError, "parameters must be created as part of `init`"): base.get_parameter("w", [], init=jnp.zeros) + def test_get_parameter_rng_exception(self): + with base.new_context(): + with self.assertRaisesRegex( + base.MissingRNGError, "pass a non-None PRNGKey to init" + ): + base.get_parameter( + "w", [], init=lambda shape, dtype: base.next_rng_key() + ) + def test_get_parameter_wrong_shape(self): with base.new_context(): with self.assertRaisesRegex(ValueError, "does not match shape"):