Skip to content

Commit

Permalink
Add tree_max.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689070753
  • Loading branch information
mblondel authored and OptaxDev committed Oct 23, 2024
1 parent 6da1711 commit ed1220e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Tree
tree_l1_norm
tree_l2_norm
tree_map_params
tree_max
tree_mul
tree_ones_like
tree_random_like
Expand Down Expand Up @@ -156,6 +157,10 @@ Tree map parameters
~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_map_params

Tree max
~~~~~~~~
.. autofunction:: tree_max

Tree multiply
~~~~~~~~~~~~~
.. autofunction:: tree_mul
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from optax.tree_utils._tree_math import tree_full_like
from optax.tree_utils._tree_math import tree_l1_norm
from optax.tree_utils._tree_math import tree_l2_norm
from optax.tree_utils._tree_math import tree_max
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
from optax.tree_utils._tree_math import tree_scalar_mul
Expand Down
14 changes: 14 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ def tree_sum(tree: Any) -> chex.Numeric:
return jax.tree.reduce(operator.add, sums, initializer=0)


def tree_max(tree: Any) -> chex.Numeric:
"""Compute the max of all the elements in a pytree.
Args:
tree: pytree.
Returns:
a scalar value.
"""
maxes = jax.tree.map(jnp.max, tree)
# initializer=-jnp.inf should work but pytype wants a jax.Array.
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))


def _square(leaf):
return jnp.square(leaf.real) + jnp.square(leaf.imag)

Expand Down
10 changes: 10 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ def test_tree_sum(self):
got = tu.tree_sum(self.tree_a)
np.testing.assert_allclose(expected, got)

@parameterized.parameters(
'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict'
)
def test_tree_max(self, key):
tree = self.data[key]
values, _ = flatten_util.ravel_pytree(tree)
expected = jnp.max(values)
got = tu.tree_max(tree)
np.testing.assert_allclose(expected, got)

def test_tree_l2_norm(self):
expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real)
got = tu.tree_l2_norm(self.array_a)
Expand Down

0 comments on commit ed1220e

Please sign in to comment.