Skip to content

Commit

Permalink
Add projection_linf_ball.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689101490
  • Loading branch information
mblondel authored and OptaxDev committed Oct 23, 2024
1 parent 6355f32 commit b8c2e13
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/api/projections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Available projections
projection_l1_sphere
projection_l2_ball
projection_l2_sphere
projection_linf_ball
projection_non_negative
projection_simplex

Expand Down Expand Up @@ -64,6 +65,10 @@ Projection onto the L2 sphere
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_l2_sphere

Projection onto the L-infinity ball
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_linf_ball

Projection onto the non-negative orthant
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_non_negative
Expand Down
1 change: 1 addition & 0 deletions optax/projections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
from optax.projections._projections import projection_l1_sphere
from optax.projections._projections import projection_l2_ball
from optax.projections._projections import projection_l2_sphere
from optax.projections._projections import projection_linf_ball
from optax.projections._projections import projection_non_negative
from optax.projections._projections import projection_simplex
23 changes: 23 additions & 0 deletions optax/projections/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,26 @@ def projection_l2_ball(pytree: Any, scale: float = 1.0) -> Any:
lambda pytree: pytree,
lambda pytree: otu.tree_scalar_mul(factor, pytree),
operand=pytree)


def projection_linf_ball(pytree: Any, scale: float = 1.0) -> Any:
r"""Projection onto the l-infinity ball.
This function solves the following constrained optimization problem,
where ``x`` is the input pytree.
.. math::
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
||y||_{\infty} \le \text{scale}
Args:
pytree: pytree to project.
scale: radius of the ball.
Returns:
projected pytree, with the same structure as ``pytree``.
"""
lower_tree = otu.tree_full_like(pytree, -scale)
upper_tree = otu.tree_full_like(pytree, scale)
return projection_box(pytree, lower=lower_tree, upper=upper_tree)
3 changes: 2 additions & 1 deletion optax/projections/_projections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def setUp(self):
self.fns = dict(
l1=(proj.projection_l1_ball, otu.tree_l1_norm),
l2=(proj.projection_l2_ball, otu.tree_l2_norm),
linf=(proj.projection_linf_ball, otu.tree_linf_norm),
)

def test_projection_non_negative(self):
Expand Down Expand Up @@ -196,7 +197,7 @@ def test_projection_l2_sphere(self, data_key, scale):

@parameterized.product(
data_key=['array_1d', 'array_2d', 'tree'],
norm=['l1', 'l2'],
norm=['l1', 'l2', 'linf'],
)
def test_projection_ball(self, data_key, norm):
"""Check correctness of the projection onto a ball."""
Expand Down

0 comments on commit b8c2e13

Please sign in to comment.