Skip to content

Commit

Permalink
chex: alias PRNGKey to jax.Array
Browse files Browse the repository at this point in the history
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict.

PiperOrigin-RevId: 565133147
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Sep 13, 2023
1 parent 533aeeb commit 6615676
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chex/_src/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
Scalar = Union[float, int]
Numeric = Union[Array, Scalar]
Shape = jax.core.Shape
PRNGKey = Union[jax.random.KeyArray, jax.Array]
PRNGKey = jax.Array
PyTreeDef = jax.tree_util.PyTreeDef
Device = jax.Device
ArrayDType = type(jnp.float32)

0 comments on commit 6615676

Please sign in to comment.