Skip to content

Commit

Permalink
Don't swallow errors with RNG keys in parameter initialisers.
Browse files Browse the repository at this point in the history
Given the following code snippet:

```python
def f(x):
  x = hk.Conv2D(3, 1)(x)
  return x

f = hk.transform(f)
x = jnp.ones([1, 28, 28])
f.init(None, x)
```

At HEAD we fail with a confusing error:

    ValueError: Parameters cannot be `None`.

After this change we fail with a far more useful:

    MissingRNGError: You must pass a non-None PRNGKey to init and/or apply if you
    make use of random numbers.

PiperOrigin-RevId: 551458310
  • Loading branch information
tomhennigan authored and copybara-github committed Jul 27, 2023
1 parent c22867d commit 6526bc8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ def get_parameter(
except MissingRNGError as e:
if frame.params_frozen:
raise param_missing_error from e
else:
raise e

if param is DO_NOT_STORE:
# Initializers or custom creators that return `DO_NOT_STORE` are required
Expand Down
9 changes: 9 additions & 0 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,15 @@ def test_parameter_in_immutable_ctx(self, params):
ValueError, "parameters must be created as part of `init`"):
base.get_parameter("w", [], init=jnp.zeros)

def test_get_parameter_rng_exception(self):
with base.new_context():
with self.assertRaisesRegex(
base.MissingRNGError, "pass a non-None PRNGKey to init"
):
base.get_parameter(
"w", [], init=lambda shape, dtype: base.next_rng_key()
)

def test_get_parameter_wrong_shape(self):
with base.new_context():
with self.assertRaisesRegex(ValueError, "does not match shape"):
Expand Down

0 comments on commit 6526bc8

Please sign in to comment.