diff --git a/docs/api/projections.rst b/docs/api/projections.rst index 243f848b..556377c2 100644 --- a/docs/api/projections.rst +++ b/docs/api/projections.rst @@ -37,6 +37,7 @@ Available projections projection_l1_sphere projection_l2_ball projection_l2_sphere + projection_linf_ball projection_non_negative projection_simplex @@ -64,6 +65,10 @@ Projection onto the L2 sphere ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: projection_l2_sphere +Projection onto the L-infinity ball +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: projection_linf_ball + Projection onto the non-negative orthant ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: projection_non_negative diff --git a/optax/projections/__init__.py b/optax/projections/__init__.py index 6e6e45e5..19171ce6 100644 --- a/optax/projections/__init__.py +++ b/optax/projections/__init__.py @@ -21,5 +21,6 @@ from optax.projections._projections import projection_l1_sphere from optax.projections._projections import projection_l2_ball from optax.projections._projections import projection_l2_sphere +from optax.projections._projections import projection_linf_ball from optax.projections._projections import projection_non_negative from optax.projections._projections import projection_simplex diff --git a/optax/projections/_projections.py b/optax/projections/_projections.py index 3b70976d..1c34f879 100644 --- a/optax/projections/_projections.py +++ b/optax/projections/_projections.py @@ -277,3 +277,26 @@ def projection_l2_ball(pytree: Any, scale: float = 1.0) -> Any: lambda pytree: pytree, lambda pytree: otu.tree_scalar_mul(factor, pytree), operand=pytree) + + +def projection_linf_ball(pytree: Any, scale: float = 1.0) -> Any: + r"""Projection onto the l-infinity ball. + + This function solves the following constrained optimization problem, + where ``x`` is the input pytree. + + .. math:: + + \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad + ||y||_{\infty} \le \text{scale} + + Args: + pytree: pytree to project. + scale: radius of the ball. + + Returns: + projected pytree, with the same structure as ``pytree``. + """ + lower_tree = otu.tree_full_like(pytree, -scale) + upper_tree = otu.tree_full_like(pytree, scale) + return projection_box(pytree, lower=lower_tree, upper=upper_tree) diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index 23b05b38..f01185be 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -43,6 +43,7 @@ def setUp(self): self.fns = dict( l1=(proj.projection_l1_ball, otu.tree_l1_norm), l2=(proj.projection_l2_ball, otu.tree_l2_norm), + linf=(proj.projection_linf_ball, otu.tree_linf_norm), ) def test_projection_non_negative(self): @@ -196,7 +197,7 @@ def test_projection_l2_sphere(self, data_key, scale): @parameterized.product( data_key=['array_1d', 'array_2d', 'tree'], - norm=['l1', 'l2'], + norm=['l1', 'l2', 'linf'], ) def test_projection_ball(self, data_key, norm): """Check correctness of the projection onto a ball."""