Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace deprecated
jax.tree_*
functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 636191717
- Loading branch information