Skip to content

Commit

Permalink
Merge pull request #690 from paulbricman:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 564985779
  • Loading branch information
copybara-github committed Sep 13, 2023
2 parents 840f441 + f0663b2 commit b1caab0
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/notebooks/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"source": [
"\n",
"\n",
"Once a Haiku network has been transformed to a pair of pure functions using `hk.transform`, it's possible to freely combine these with any JAX transformations like `jax.jit`, `jax.grad`, `jax.scan` and so on.\n",
"Once a Haiku network has been transformed to a pair of pure functions using `hk.transform`, it's possible to freely combine these with any JAX transformations like `jax.jit`, `jax.grad`, `jax.lax.scan` and so on.\n",
"\n",
"If you want to use JAX transformations **inside** of a `hk.transform` however, you need to be more careful. It's possible, but most functions inside of the `hk.transform` boundary are still side effecting, and cannot safely be transformed by JAX.\n",
"This is a common cause of `UnexpectedTracerError`s in code using Haiku. These errors are a result of using a JAX transform on a side effecting function (for more information on this JAX error, see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError).\n",
Expand Down Expand Up @@ -93,7 +93,7 @@
"id": "pZrXKtC0C3lX"
},
"source": [
"These examples use `jax.eval_shape`, but could have used any higher-order JAX function (eg. `jax.vmap`, `jax.scan`, \n",
"These examples use `jax.eval_shape`, but could have used any higher-order JAX function (eg. `jax.vmap`, `jax.lax.scan`, \n",
"`jax.while_loop`, ...).\n",
"\n",
"The error points to `hk.get_parameter`. This is the operation which makes `net` a side effecting function. The side effect in this case is the creation of a parameter, which gets stored into the Haiku state. Similarly you would get an error using `hk.next_rng_key`, because it advances the Haiku RNG state and stores a new PRNGKey into the Haiku state. In general, transforming a non-transformed Haiku module will result in an `UnexpectedTracerError`.\n",
Expand Down
2 changes: 1 addition & 1 deletion haiku/_src/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1):
"""Equivalent to :func:`jax.lax.scan` but with Haiku state passed in/out."""
if not base.inside_transform():
raise ValueError("hk.scan() should not be used outside of hk.transform(). "
"Use jax.scan() instead.")
"Use jax.lax.scan() instead.")

if length is None:
length = jax.tree_util.tree_leaves(xs)[0].shape[0]
Expand Down
2 changes: 1 addition & 1 deletion haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def test_difference_rng(self):

def test_scan_no_transform(self):
xs = jnp.arange(3)
with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
with self.assertRaises(ValueError, msg="Use jax.lax.scan() instead"):
stateful.scan(lambda c, x: (c, x), (), xs)

@parameterized.parameters(0, 1, 2, 4, 8)
Expand Down
2 changes: 1 addition & 1 deletion haiku/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def transform_with_state(f) -> TransformedWithState:
"An UnexpectedTracerError was raised while inside a Haiku transformed "
"function (see error above).\n"
"Hint: are you using a JAX transform or JAX control-flow function "
"(jax.vmap/jax.scan/...) inside a Haiku transform? You might want to use "
"(jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use "
"the Haiku version of the transform instead (hk.vmap/hk.scan/...).\n"
"See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html "
"on why you can't use JAX transforms inside a Haiku module.")
Expand Down

0 comments on commit b1caab0

Please sign in to comment.