diff --git a/mt3/layers_test.py b/mt3/layers_test.py index 4d9310f..40ce63f 100644 --- a/mt3/layers_test.py +++ b/mt3/layers_test.py @@ -500,7 +500,7 @@ def test_mlp_same_out_dim(self): dtype=np.float32) params = module.init(random.PRNGKey(0), inputs, deterministic=True) self.assertEqual( - jax.tree_map(lambda a: a.tolist(), params), { + jax.tree.map(lambda a: a.tolist(), params), { 'params': { 'wi': { 'kernel': [[