From f0663b2605c819ae5753c45e66a8b3648ca6353b Mon Sep 17 00:00:00 2001 From: paulbricman Date: Tue, 11 Jul 2023 13:40:18 +0200 Subject: [PATCH] fix: edit hints, comments mentioning jax.scan() to jax.lax.scan() --- docs/notebooks/transforms.ipynb | 4 ++-- haiku/_src/stateful.py | 2 +- haiku/_src/stateful_test.py | 2 +- haiku/_src/transform.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/notebooks/transforms.ipynb b/docs/notebooks/transforms.ipynb index 0b11f84f6..d7a9e6f2b 100644 --- a/docs/notebooks/transforms.ipynb +++ b/docs/notebooks/transforms.ipynb @@ -40,7 +40,7 @@ "source": [ "\n", "\n", - "Once a Haiku network has been transformed to a pair of pure functions using `hk.transform`, it's possible to freely combine these with any JAX transformations like `jax.jit`, `jax.grad`, `jax.scan` and so on.\n", + "Once a Haiku network has been transformed to a pair of pure functions using `hk.transform`, it's possible to freely combine these with any JAX transformations like `jax.jit`, `jax.grad`, `jax.lax.scan` and so on.\n", "\n", "If you want to use JAX transformations **inside** of a `hk.transform` however, you need to be more careful. It's possible, but most functions inside of the `hk.transform` boundary are still side effecting, and cannot safely be transformed by JAX.\n", "This is a common cause of `UnexpectedTracerError`s in code using Haiku. These errors are a result of using a JAX transform on a side effecting function (for more information on this JAX error, see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError).\n", @@ -93,7 +93,7 @@ "id": "pZrXKtC0C3lX" }, "source": [ - "These examples use `jax.eval_shape`, but could have used any higher-order JAX function (eg. `jax.vmap`, `jax.scan`, \n", + "These examples use `jax.eval_shape`, but could have used any higher-order JAX function (eg. `jax.vmap`, `jax.lax.scan`, \n", "`jax.while_loop`, ...).\n", "\n", "The error points to `hk.get_parameter`. This is the operation which makes `net` a side effecting function. The side effect in this case is the creation of a parameter, which gets stored into the Haiku state. Similarly you would get an error using `hk.next_rng_key`, because it advances the Haiku RNG state and stores a new PRNGKey into the Haiku state. In general, transforming a non-transformed Haiku module will result in an `UnexpectedTracerError`.\n", diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py index ecd48af4e..383158c64 100644 --- a/haiku/_src/stateful.py +++ b/haiku/_src/stateful.py @@ -586,7 +586,7 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): """Equivalent to :func:`jax.lax.scan` but with Haiku state passed in/out.""" if not base.inside_transform(): raise ValueError("hk.scan() should not be used outside of hk.transform(). " - "Use jax.scan() instead.") + "Use jax.lax.scan() instead.") if length is None: length = jax.tree_util.tree_leaves(xs)[0].shape[0] diff --git a/haiku/_src/stateful_test.py b/haiku/_src/stateful_test.py index 0d757c7a6..e2e826df7 100644 --- a/haiku/_src/stateful_test.py +++ b/haiku/_src/stateful_test.py @@ -457,7 +457,7 @@ def test_difference_rng(self): def test_scan_no_transform(self): xs = jnp.arange(3) - with self.assertRaises(ValueError, msg="Use jax.scan() instead"): + with self.assertRaises(ValueError, msg="Use jax.lax.scan() instead"): stateful.scan(lambda c, x: (c, x), (), xs) @parameterized.parameters(0, 1, 2, 4, 8) diff --git a/haiku/_src/transform.py b/haiku/_src/transform.py index 45a296641..e5fdeac34 100644 --- a/haiku/_src/transform.py +++ b/haiku/_src/transform.py @@ -408,7 +408,7 @@ def transform_with_state(f) -> TransformedWithState: "An UnexpectedTracerError was raised while inside a Haiku transformed " "function (see error above).\n" "Hint: are you using a JAX transform or JAX control-flow function " - "(jax.vmap/jax.scan/...) inside a Haiku transform? You might want to use " + "(jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use " "the Haiku version of the transform instead (hk.vmap/hk.scan/...).\n" "See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html " "on why you can't use JAX transforms inside a Haiku module.")