Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 636191717
  • Loading branch information
Jake VanderPlas authored and Magenta Team committed May 22, 2024
1 parent 2f8e2c8 commit 7b4d245
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion mt3/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def test_mlp_same_out_dim(self):
dtype=np.float32)
params = module.init(random.PRNGKey(0), inputs, deterministic=True)
self.assertEqual(
jax.tree_map(lambda a: a.tolist(), params), {
jax.tree.map(lambda a: a.tolist(), params), {
'params': {
'wi': {
'kernel': [[
Expand Down

0 comments on commit 7b4d245

Please sign in to comment.