From 7b4d245352f6ff84072d1ff005f81bcda9bbcc87 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 22 May 2024 09:21:02 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 636191717 --- mt3/layers_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mt3/layers_test.py b/mt3/layers_test.py index 4d9310f..40ce63f 100644 --- a/mt3/layers_test.py +++ b/mt3/layers_test.py @@ -500,7 +500,7 @@ def test_mlp_same_out_dim(self): dtype=np.float32) params = module.init(random.PRNGKey(0), inputs, deterministic=True) self.assertEqual( - jax.tree_map(lambda a: a.tolist(), params), { + jax.tree.map(lambda a: a.tolist(), params), { 'params': { 'wi': { 'kernel': [[