Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 566933252
- Loading branch information