diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 816f82ee..9b909431 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -100,6 +100,7 @@ Tree tree_l1_norm tree_l2_norm tree_map_params + tree_max tree_mul tree_ones_like tree_random_like @@ -156,6 +157,10 @@ Tree map parameters ~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_map_params +Tree max +~~~~~~~~ +.. autofunction:: tree_max + Tree multiply ~~~~~~~~~~~~~ .. autofunction:: tree_mul diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index d984a100..9675ad02 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -32,6 +32,7 @@ from optax.tree_utils._tree_math import tree_full_like from optax.tree_utils._tree_math import tree_l1_norm from optax.tree_utils._tree_math import tree_l2_norm +from optax.tree_utils._tree_math import tree_max from optax.tree_utils._tree_math import tree_mul from optax.tree_utils._tree_math import tree_ones_like from optax.tree_utils._tree_math import tree_scalar_mul diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 13673549..3d52c23c 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -168,6 +168,20 @@ def tree_sum(tree: Any) -> chex.Numeric: return jax.tree.reduce(operator.add, sums, initializer=0) +def tree_max(tree: Any) -> chex.Numeric: + """Compute the max of all the elements in a pytree. + + Args: + tree: pytree. + + Returns: + a scalar value. + """ + maxes = jax.tree.map(jnp.max, tree) + # initializer=-jnp.inf should work but pytype wants a jax.Array. + return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf)) + + def _square(leaf): return jnp.square(leaf.real) + jnp.square(leaf.imag) diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index f0f326ce..054f239d 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -133,6 +133,16 @@ def test_tree_sum(self): got = tu.tree_sum(self.tree_a) np.testing.assert_allclose(expected, got) + @parameterized.parameters( + 'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict' + ) + def test_tree_max(self, key): + tree = self.data[key] + values, _ = flatten_util.ravel_pytree(tree) + expected = jnp.max(values) + got = tu.tree_max(tree) + np.testing.assert_allclose(expected, got) + def test_tree_l2_norm(self): expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real) got = tu.tree_l2_norm(self.array_a)