diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index 28ec57f5..98ca33e6 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -519,7 +519,7 @@ def masked_negate(updates): with self.subTest('tree_map_params'): result = state_utils.tree_map_params(init_fn, lambda v: v, state) - chex.assert_tree_all_equal_structs(result, state) + chex.assert_trees_all_equal_structs(result, state) updates, state = update_fn(input_updates, state, params) chex.assert_trees_all_close(updates, correct_updates)