Skip to content

Commit

Permalink
Add projection_l1_sphere and projection_l1_ball.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686114713
  • Loading branch information
mblondel authored and OptaxDev committed Oct 18, 2024
1 parent 3d8f8be commit dd29fe1
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 2 deletions.
10 changes: 10 additions & 0 deletions docs/api/projections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Available projections
.. autosummary::
projection_box
projection_hypercube
projection_l1_ball
projection_l1_sphere
projection_non_negative
projection_simplex

Expand All @@ -44,6 +46,14 @@ Projection onto a hypercube
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_hypercube

Projection onto the L1 ball
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_l1_ball

Projection onto the L1 sphere
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_l1_sphere

Projection onto the non-negative orthant
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_non_negative
Expand Down
2 changes: 2 additions & 0 deletions optax/projections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@

from optax.projections._projections import projection_box
from optax.projections._projections import projection_hypercube
from optax.projections._projections import projection_l1_ball
from optax.projections._projections import projection_l1_sphere
from optax.projections._projections import projection_non_negative
from optax.projections._projections import projection_simplex
69 changes: 67 additions & 2 deletions optax/projections/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import jax.numpy as jnp


from optax import tree_utils as otu


def projection_non_negative(pytree: Any) -> Any:
r"""Projection onto the non-negative orthant.
Expand Down Expand Up @@ -148,10 +151,10 @@ def projection_simplex(pytree: Any,
>>> from optax import tree_utils, projections
>>> pytree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> tree_utils.tree_sum(pytree)
6.2
Array(6.2, dtype=float32)
>>> new_pytree = projections.projection_simplex(pytree)
>>> tree_utils.tree_sum(new_pytree)
1.0000002
Array(1.0000002, dtype=float32)
"""
if scale is None:
scale = 1.0
Expand All @@ -160,3 +163,65 @@ def projection_simplex(pytree: Any,
new_values = scale * _projection_unit_simplex(values / scale)

return unravel_fn(new_values)


def projection_l1_sphere(pytree: Any, scale: float = 1.0) -> Any:
r"""Projection onto the l1 sphere.
This function solves the following constrained optimization problem,
where ``x`` is the input pytree.
.. math::
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
||y||_1 = \text{scale}
Args:
pytree: pytree to project.
scale: radius of the sphere.
Returns:
projected pytree, with the same structure as ``pytree``.
"""
tree_abs = jax.tree.map(jnp.abs, pytree)
tree_sign = jax.tree.map(jnp.sign, pytree)
tree_abs_proj = projection_simplex(tree_abs, scale)
return otu.tree_mul(tree_sign, tree_abs_proj)


def projection_l1_ball(pytree: Any, scale: float = 1.0) -> Any:
r"""Projection onto the l1 ball.
This function solves the following constrained optimization problem,
where ``x`` is the input pytree.
.. math::
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
||y||_1 \le \text{scale}
Args:
pytree: pytree to project.
scale: radius of the ball.
Returns:
projected pytree, with the same structure as ``pytree``.
.. versionadded:: 0.2.4
Example:
>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> pytree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> tree_utils.tree_l1_norm(pytree)
6.2
>>> new_pytree = projections.projection_l1_ball(pytree)
>>> tree_utils.tree_sum(new_pytree)
1.000002
"""
l1_norm = otu.tree_l1_norm(pytree)
return jax.lax.cond(l1_norm <= scale,
lambda pytree: pytree,
lambda pytree: projection_l1_sphere(pytree, scale),
operand=pytree)
38 changes: 38 additions & 0 deletions optax/projections/_projections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def projection_simplex_jacobian(projection):

class ProjectionsTest(parameterized.TestCase):

def setUp(self):
super().setUp()
array_1d = jnp.array([0.5, 2.1, -3.5])
array_2d = jnp.array([[0.5, 2.1, -3.5], [1.0, 2.0, 3.0]])
tree = (array_1d, array_1d)
self.data = dict(array_1d=array_1d, array_2d=array_2d, tree=tree)

def test_projection_non_negative(self):
with self.subTest('with an array'):
x = jnp.array([-1.0, 2.0, 3.0])
Expand Down Expand Up @@ -167,6 +174,37 @@ def test_projection_simplex_vmap(self, scale):
np.testing.assert_array_equal(True, 0 <= p)
np.testing.assert_array_equal(True, p <= scale)

@parameterized.product(data_key=['array_1d', 'array_2d', 'tree'],
scale=[1.0, 3.21])
def test_projection_l1_sphere(self, data_key, scale):
x = self.data[data_key]
p = proj.projection_l1_sphere(x, scale)
# Check that the projection has the correct l1 norm.
np.testing.assert_almost_equal(otu.tree_l1_norm(p), scale, decimal=4)

def _check_projection_ball(self, x, ball_proj, norm_fun):
"""Check correctness of the projection onto a ball."""
norm_value = norm_fun(x)

with self.subTest('Check when input is already in the ball'):
big_radius = norm_value * 2
p = ball_proj(x, big_radius)
np.testing.assert_array_almost_equal(x, p)

with self.subTest('Check when input is on the boundary of the ball'):
p = ball_proj(x, norm_value)
np.testing.assert_array_almost_equal(x, p)

with self.subTest('Check when input is outside the ball'):
small_radius = norm_value / 2
p = ball_proj(x, small_radius)
np.testing.assert_almost_equal(norm_fun(p), small_radius, decimal=4)

@parameterized.parameters('array_1d', 'array_2d', 'tree')
def test_projection_l1_ball(self, data_key):
x = self.data[data_key]
self._check_projection_ball(x, proj.projection_l1_ball, otu.tree_l1_norm)


if __name__ == '__main__':
absltest.main()

0 comments on commit dd29fe1

Please sign in to comment.