From 259e20fbeb77b612c247834a85f844a592fa65c6 Mon Sep 17 00:00:00 2001 From: Mathieu Blondel Date: Tue, 22 Oct 2024 08:51:39 -0700 Subject: [PATCH] Add projection_linf_ball. Also fix assertions that were incorrect in the linf norm case. PiperOrigin-RevId: 688567383 --- docs/api/projections.rst | 15 ++++++ docs/api/utilities.rst | 5 ++ optax/projections/__init__.py | 3 ++ optax/projections/_projections.py | 75 ++++++++++++++++++++++++++ optax/projections/_projections_test.py | 61 +++++++++++++++------ optax/tree_utils/__init__.py | 2 + optax/tree_utils/_tree_math.py | 27 ++++++++++ optax/tree_utils/_tree_math_test.py | 19 +++++++ 8 files changed, 191 insertions(+), 16 deletions(-) diff --git a/docs/api/projections.rst b/docs/api/projections.rst index 683fde54..556377c2 100644 --- a/docs/api/projections.rst +++ b/docs/api/projections.rst @@ -35,6 +35,9 @@ Available projections projection_hypercube projection_l1_ball projection_l1_sphere + projection_l2_ball + projection_l2_sphere + projection_linf_ball projection_non_negative projection_simplex @@ -54,6 +57,18 @@ Projection onto the L1 sphere ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: projection_l1_sphere +Projection onto the L2 ball +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: projection_l2_ball + +Projection onto the L2 sphere +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: projection_l2_sphere + +Projection onto the L-infinity ball +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: projection_linf_ball + Projection onto the non-negative orthant ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: projection_non_negative diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 816f82ee..9b909431 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -100,6 +100,7 @@ Tree tree_l1_norm tree_l2_norm tree_map_params + tree_max tree_mul tree_ones_like tree_random_like @@ -156,6 +157,10 @@ Tree map parameters ~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_map_params +Tree max +~~~~~~~~ +.. autofunction:: tree_max + Tree multiply ~~~~~~~~~~~~~ .. autofunction:: tree_mul diff --git a/optax/projections/__init__.py b/optax/projections/__init__.py index 2d8646bb..19171ce6 100644 --- a/optax/projections/__init__.py +++ b/optax/projections/__init__.py @@ -19,5 +19,8 @@ from optax.projections._projections import projection_hypercube from optax.projections._projections import projection_l1_ball from optax.projections._projections import projection_l1_sphere +from optax.projections._projections import projection_l2_ball +from optax.projections._projections import projection_l2_sphere +from optax.projections._projections import projection_linf_ball from optax.projections._projections import projection_non_negative from optax.projections._projections import projection_simplex diff --git a/optax/projections/_projections.py b/optax/projections/_projections.py index 1d799e5f..1c34f879 100644 --- a/optax/projections/_projections.py +++ b/optax/projections/_projections.py @@ -225,3 +225,78 @@ def projection_l1_ball(pytree: Any, scale: float = 1.0) -> Any: lambda pytree: pytree, lambda pytree: projection_l1_sphere(pytree, scale), operand=pytree) + + +def projection_l2_sphere(pytree: Any, scale: float = 1.0) -> Any: + r"""Projection onto the l2 sphere. + + This function solves the following constrained optimization problem, + where ``x`` is the input pytree. + + .. math:: + + \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad + ||y||_2 = \text{value} + + Args: + pytree: pytree to project. + scale: radius of the sphere. + + Returns: + projected pytree, with the same structure as ``pytree``. + + .. versionadded:: 0.2.4 + """ + factor = scale / otu.tree_l2_norm(pytree) + return otu.tree_scalar_mul(factor, pytree) + + +def projection_l2_ball(pytree: Any, scale: float = 1.0) -> Any: + r"""Projection onto the l2 ball. + + This function solves the following constrained optimization problem, + where ``x`` is the input pytree. + + .. math:: + + \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad + ||y||_2 \le \text{scale} + + Args: + pytree: pytree to project. + scale: radius of the ball. + + Returns: + projected pytree, with the same structure as ``pytree``. + + .. versionadded:: 0.2.4 + """ + l2_norm = otu.tree_l2_norm(pytree) + factor = scale / l2_norm + return jax.lax.cond(l2_norm <= scale, + lambda pytree: pytree, + lambda pytree: otu.tree_scalar_mul(factor, pytree), + operand=pytree) + + +def projection_linf_ball(pytree: Any, scale: float = 1.0) -> Any: + r"""Projection onto the l-infinity ball. + + This function solves the following constrained optimization problem, + where ``x`` is the input pytree. + + .. math:: + + \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad + ||y||_{\infty} \le \text{scale} + + Args: + pytree: pytree to project. + scale: radius of the ball. + + Returns: + projected pytree, with the same structure as ``pytree``. + """ + lower_tree = otu.tree_full_like(pytree, -scale) + upper_tree = otu.tree_full_like(pytree, scale) + return projection_box(pytree, lower=lower_tree, upper=upper_tree) diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index e4bd17c7..c39b3dad 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -174,37 +174,66 @@ def test_projection_simplex_vmap(self, scale): np.testing.assert_array_equal(True, 0 <= p) np.testing.assert_array_equal(True, p <= scale) - @parameterized.product(data_key=['array_1d', 'array_2d', 'tree'], - scale=[1.0, 3.21]) - def test_projection_l1_sphere(self, data_key, scale): + def _test_projection_sphere(self, data_key, scale, proj_fun, norm_fun): + """Check correctness of the projection onto a sphere.""" x = self.data[data_key] - p = proj.projection_l1_sphere(x, scale) - # Check that the projection has the correct l1 norm. - np.testing.assert_almost_equal(otu.tree_l1_norm(p), scale, decimal=4) + p = proj_fun(x, scale) + np.testing.assert_almost_equal(norm_fun(p), scale, decimal=4) + + @parameterized.product( + data_key=['array_1d', 'array_2d', 'tree'], scale=[1.0, 3.21] + ) + def test_projection_l1_sphere(self, data_key, scale): + self._test_projection_sphere( + data_key, scale, proj.projection_l1_sphere, otu.tree_l1_norm + ) - def _check_projection_ball(self, x, ball_proj, norm_fun): + @parameterized.product( + data_key=['array_1d', 'array_2d', 'tree'], scale=[1.0, 3.21] + ) + def test_projection_l2_sphere(self, data_key, scale): + self._test_projection_sphere( + data_key, scale, proj.projection_l2_sphere, otu.tree_l2_norm + ) + + def _test_projection_ball(self, data_key, proj_fun, norm_fun): """Check correctness of the projection onto a ball.""" + x = self.data[data_key] + eps = 1e-4 + norm_value = norm_fun(x) with self.subTest('Check when input is already in the ball'): big_radius = norm_value * 2 - p = ball_proj(x, big_radius) - np.testing.assert_array_almost_equal(x, p) + p = proj_fun(x, big_radius) + self.assertLessEqual(norm_fun(p), big_radius + eps) with self.subTest('Check when input is on the boundary of the ball'): - p = ball_proj(x, norm_value) - np.testing.assert_array_almost_equal(x, p) + p = proj_fun(x, norm_value) + self.assertLessEqual(norm_fun(p), norm_value + eps) with self.subTest('Check when input is outside the ball'): small_radius = norm_value / 2 - p = ball_proj(x, small_radius) - np.testing.assert_almost_equal(norm_fun(p), small_radius, decimal=4) + p = proj_fun(x, small_radius) + self.assertLessEqual(norm_fun(p), small_radius + eps) - @parameterized.parameters('array_1d', 'array_2d', 'tree') + @parameterized.product(data_key=['array_1d', 'array_2d', 'tree']) def test_projection_l1_ball(self, data_key): - x = self.data[data_key] - self._check_projection_ball(x, proj.projection_l1_ball, otu.tree_l1_norm) + self._test_projection_ball( + data_key, proj.projection_l1_ball, otu.tree_l1_norm + ) + @parameterized.product(data_key=['array_1d', 'array_2d', 'tree']) + def test_projection_l2_ball(self, data_key): + self._test_projection_ball( + data_key, proj.projection_l2_ball, otu.tree_l2_norm + ) + + @parameterized.product(data_key=['array_1d', 'array_2d', 'tree']) + def test_projection_linf_ball(self, data_key): + self._test_projection_ball( + data_key, proj.projection_linf_ball, otu.tree_linf_norm + ) if __name__ == '__main__': absltest.main() diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index d984a100..0887a9ab 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -32,6 +32,8 @@ from optax.tree_utils._tree_math import tree_full_like from optax.tree_utils._tree_math import tree_l1_norm from optax.tree_utils._tree_math import tree_l2_norm +from optax.tree_utils._tree_math import tree_linf_norm +from optax.tree_utils._tree_math import tree_max from optax.tree_utils._tree_math import tree_mul from optax.tree_utils._tree_math import tree_ones_like from optax.tree_utils._tree_math import tree_scalar_mul diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 13673549..1714663e 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -168,6 +168,20 @@ def tree_sum(tree: Any) -> chex.Numeric: return jax.tree.reduce(operator.add, sums, initializer=0) +def tree_max(tree: Any) -> chex.Numeric: + """Compute the max of all the elements in a pytree. + + Args: + tree: pytree. + + Returns: + a scalar value. + """ + maxes = jax.tree.map(jnp.max, tree) + # initializer=-jnp.inf should work but pytype doesn't like it. + return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf)) + + def _square(leaf): return jnp.square(leaf.real) + jnp.square(leaf.imag) @@ -203,6 +217,19 @@ def tree_l1_norm(tree: Any) -> chex.Numeric: return tree_sum(abs_tree) +def tree_linf_norm(tree: Any) -> chex.Numeric: + """Compute the l-infinity norm of a pytree. + + Args: + tree: pytree. + + Returns: + a scalar value. + """ + abs_tree = jax.tree.map(jnp.abs, tree) + return tree_max(abs_tree) + + def tree_zeros_like( tree: Any, dtype: Optional[jax.typing.DTypeLike] = None, diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index f0f326ce..989c2ec7 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -133,6 +133,15 @@ def test_tree_sum(self): got = tu.tree_sum(self.tree_a) np.testing.assert_allclose(expected, got) + def test_tree_max(self): + expected = jnp.max(self.array_a) + got = tu.tree_max(self.array_a) + np.testing.assert_allclose(expected, got) + + expected = max(jnp.max(self.tree_a[0]), jnp.max(self.tree_a[1])) + got = tu.tree_max(self.tree_a) + np.testing.assert_allclose(expected, got) + def test_tree_l2_norm(self): expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real) got = tu.tree_l2_norm(self.array_a) @@ -153,6 +162,16 @@ def test_tree_l1_norm(self, key): got = tu.tree_l1_norm(tree) np.testing.assert_allclose(expected, got, atol=1e-4) + @parameterized.parameters( + 'tree_a', 'tree_a_dict', 'tree_b', 'array_a', 'array_b', 'tree_b_dict' + ) + def test_tree_linf_norm(self, key): + tree = self.data[key] + values, _ = flatten_util.ravel_pytree(tree) + expected = jnp.max(jnp.abs(values)) + got = tu.tree_linf_norm(tree) + np.testing.assert_allclose(expected, got, atol=1e-4) + def test_tree_zeros_like(self): expected = jnp.zeros_like(self.array_a) got = tu.tree_zeros_like(self.array_a)