Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Fix incorrect type annotations.
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases. PiperOrigin-RevId: 556042678
- Loading branch information