diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 9b909431..f404113f 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -99,6 +99,7 @@ Tree tree_get_all_with_path tree_l1_norm tree_l2_norm + tree_linf_norm tree_map_params tree_max tree_mul @@ -153,6 +154,10 @@ Tree l2 norm ~~~~~~~~~~~~ .. autofunction:: tree_l2_norm +Tree l-infinity norm +~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_linf_norm + Tree map parameters ~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_map_params diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index 9675ad02..0887a9ab 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -32,6 +32,7 @@ from optax.tree_utils._tree_math import tree_full_like from optax.tree_utils._tree_math import tree_l1_norm from optax.tree_utils._tree_math import tree_l2_norm +from optax.tree_utils._tree_math import tree_linf_norm from optax.tree_utils._tree_math import tree_max from optax.tree_utils._tree_math import tree_mul from optax.tree_utils._tree_math import tree_ones_like diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 3d52c23c..ca969141 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -217,6 +217,19 @@ def tree_l1_norm(tree: Any) -> chex.Numeric: return tree_sum(abs_tree) +def tree_linf_norm(tree: Any) -> chex.Numeric: + """Compute the l-infinity norm of a pytree. + + Args: + tree: pytree. + + Returns: + a scalar value. + """ + abs_tree = jax.tree.map(jnp.abs, tree) + return tree_max(abs_tree) + + def tree_zeros_like( tree: Any, dtype: Optional[jax.typing.DTypeLike] = None, diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index 054f239d..3feb4e13 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -163,6 +163,16 @@ def test_tree_l1_norm(self, key): got = tu.tree_l1_norm(tree) np.testing.assert_allclose(expected, got, atol=1e-4) + @parameterized.parameters( + 'tree_a', 'tree_a_dict', 'tree_b', 'array_a', 'array_b', 'tree_b_dict' + ) + def test_tree_linf_norm(self, key): + tree = self.data[key] + values, _ = flatten_util.ravel_pytree(tree) + expected = jnp.max(jnp.abs(values)) + got = tu.tree_linf_norm(tree) + np.testing.assert_allclose(expected, got, atol=1e-4) + def test_tree_zeros_like(self): expected = jnp.zeros_like(self.array_a) got = tu.tree_zeros_like(self.array_a)