Skip to content

Commit

Permalink
Add tree_linf_norm.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688548808
  • Loading branch information
mblondel authored and OptaxDev committed Oct 23, 2024
1 parent ed1220e commit 3d20701
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Tree
tree_get_all_with_path
tree_l1_norm
tree_l2_norm
tree_linf_norm
tree_map_params
tree_max
tree_mul
Expand Down Expand Up @@ -153,6 +154,10 @@ Tree l2 norm
~~~~~~~~~~~~
.. autofunction:: tree_l2_norm

Tree l-infinity norm
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_linf_norm

Tree map parameters
~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_map_params
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from optax.tree_utils._tree_math import tree_full_like
from optax.tree_utils._tree_math import tree_l1_norm
from optax.tree_utils._tree_math import tree_l2_norm
from optax.tree_utils._tree_math import tree_linf_norm
from optax.tree_utils._tree_math import tree_max
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
Expand Down
13 changes: 13 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,19 @@ def tree_l1_norm(tree: Any) -> chex.Numeric:
return tree_sum(abs_tree)


def tree_linf_norm(tree: Any) -> chex.Numeric:
"""Compute the l-infinity norm of a pytree.
Args:
tree: pytree.
Returns:
a scalar value.
"""
abs_tree = jax.tree.map(jnp.abs, tree)
return tree_max(abs_tree)


def tree_zeros_like(
tree: Any,
dtype: Optional[jax.typing.DTypeLike] = None,
Expand Down
10 changes: 10 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ def test_tree_l1_norm(self, key):
got = tu.tree_l1_norm(tree)
np.testing.assert_allclose(expected, got, atol=1e-4)

@parameterized.parameters(
'tree_a', 'tree_a_dict', 'tree_b', 'array_a', 'array_b', 'tree_b_dict'
)
def test_tree_linf_norm(self, key):
tree = self.data[key]
values, _ = flatten_util.ravel_pytree(tree)
expected = jnp.max(jnp.abs(values))
got = tu.tree_linf_norm(tree)
np.testing.assert_allclose(expected, got, atol=1e-4)

def test_tree_zeros_like(self):
expected = jnp.zeros_like(self.array_a)
got = tu.tree_zeros_like(self.array_a)
Expand Down

0 comments on commit 3d20701

Please sign in to comment.