diff --git a/docs/api/projections.rst b/docs/api/projections.rst index 8d38166b4..683fde542 100644 --- a/docs/api/projections.rst +++ b/docs/api/projections.rst @@ -33,6 +33,8 @@ Available projections .. autosummary:: projection_box projection_hypercube + projection_l1_ball + projection_l1_sphere projection_non_negative projection_simplex @@ -44,6 +46,14 @@ Projection onto a hypercube ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: projection_hypercube +Projection onto the L1 ball +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: projection_l1_ball + +Projection onto the L1 sphere +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: projection_l1_sphere + Projection onto the non-negative orthant ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: projection_non_negative diff --git a/optax/projections/__init__.py b/optax/projections/__init__.py index 70fbd949b..2d8646bb2 100644 --- a/optax/projections/__init__.py +++ b/optax/projections/__init__.py @@ -17,5 +17,7 @@ from optax.projections._projections import projection_box from optax.projections._projections import projection_hypercube +from optax.projections._projections import projection_l1_ball +from optax.projections._projections import projection_l1_sphere from optax.projections._projections import projection_non_negative from optax.projections._projections import projection_simplex diff --git a/optax/projections/_projections.py b/optax/projections/_projections.py index 107a7dbea..ec723f457 100644 --- a/optax/projections/_projections.py +++ b/optax/projections/_projections.py @@ -24,6 +24,9 @@ import jax.numpy as jnp +from optax import tree_utils as otu + + def projection_non_negative(pytree: Any) -> Any: r"""Projection onto the non-negative orthant. @@ -148,10 +151,10 @@ def projection_simplex(pytree: Any, >>> from optax import tree_utils, projections >>> pytree = {"w": jnp.array([2.5, 3.2]), "b": 0.5} >>> tree_utils.tree_sum(pytree) - 6.2 + Array(6.2, dtype=float32) >>> new_pytree = projections.projection_simplex(pytree) >>> tree_utils.tree_sum(new_pytree) - 1.0000002 + Array(1.0000002, dtype=float32) """ if scale is None: scale = 1.0 @@ -160,3 +163,65 @@ def projection_simplex(pytree: Any, new_values = scale * _projection_unit_simplex(values / scale) return unravel_fn(new_values) + + +def projection_l1_sphere(pytree: Any, scale: float = 1.0) -> Any: + r"""Projection onto the l1 sphere. + + This function solves the following constrained optimization problem, + where ``x`` is the input pytree. + + .. math:: + + \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad + ||y||_1 = \text{scale} + + Args: + pytree: pytree to project. + scale: radius of the sphere. + + Returns: + projected pytree, with the same structure as ``pytree``. + """ + tree_abs = jax.tree.map(jnp.abs, pytree) + tree_sign = jax.tree.map(jnp.sign, pytree) + tree_abs_proj = projection_simplex(tree_abs, scale) + return otu.tree_mul(tree_sign, tree_abs_proj) + + +def projection_l1_ball(pytree: Any, scale: float = 1.0) -> Any: + r"""Projection onto the l1 ball. + + This function solves the following constrained optimization problem, + where ``x`` is the input pytree. + + .. math:: + + \underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad + ||y||_1 \le \text{scale} + + Args: + pytree: pytree to project. + scale: radius of the ball. + + Returns: + projected pytree, with the same structure as ``pytree``. + + .. versionadded:: 0.2.4 + + Example: + + >>> import jax.numpy as jnp + >>> from optax import tree_utils, projections + >>> pytree = {"w": jnp.array([2.5, 3.2]), "b": 0.5} + >>> tree_utils.tree_l1_norm(pytree) + 6.2 + >>> new_pytree = projections.projection_l1_ball(pytree) + >>> tree_utils.tree_sum(new_pytree) + 1.000002 + """ + l1_norm = otu.tree_l1_norm(pytree) + return jax.lax.cond(l1_norm <= scale, + lambda pytree: pytree, + lambda pytree: projection_l1_sphere(pytree, scale), + operand=pytree) diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index b71e28d1a..e4bd17c76 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -34,6 +34,13 @@ def projection_simplex_jacobian(projection): class ProjectionsTest(parameterized.TestCase): + def setUp(self): + super().setUp() + array_1d = jnp.array([0.5, 2.1, -3.5]) + array_2d = jnp.array([[0.5, 2.1, -3.5], [1.0, 2.0, 3.0]]) + tree = (array_1d, array_1d) + self.data = dict(array_1d=array_1d, array_2d=array_2d, tree=tree) + def test_projection_non_negative(self): with self.subTest('with an array'): x = jnp.array([-1.0, 2.0, 3.0]) @@ -167,6 +174,37 @@ def test_projection_simplex_vmap(self, scale): np.testing.assert_array_equal(True, 0 <= p) np.testing.assert_array_equal(True, p <= scale) + @parameterized.product(data_key=['array_1d', 'array_2d', 'tree'], + scale=[1.0, 3.21]) + def test_projection_l1_sphere(self, data_key, scale): + x = self.data[data_key] + p = proj.projection_l1_sphere(x, scale) + # Check that the projection has the correct l1 norm. + np.testing.assert_almost_equal(otu.tree_l1_norm(p), scale, decimal=4) + + def _check_projection_ball(self, x, ball_proj, norm_fun): + """Check correctness of the projection onto a ball.""" + norm_value = norm_fun(x) + + with self.subTest('Check when input is already in the ball'): + big_radius = norm_value * 2 + p = ball_proj(x, big_radius) + np.testing.assert_array_almost_equal(x, p) + + with self.subTest('Check when input is on the boundary of the ball'): + p = ball_proj(x, norm_value) + np.testing.assert_array_almost_equal(x, p) + + with self.subTest('Check when input is outside the ball'): + small_radius = norm_value / 2 + p = ball_proj(x, small_radius) + np.testing.assert_almost_equal(norm_fun(p), small_radius, decimal=4) + + @parameterized.parameters('array_1d', 'array_2d', 'tree') + def test_projection_l1_ball(self, data_key): + x = self.data[data_key] + self._check_projection_ball(x, proj.projection_l1_ball, otu.tree_l1_norm) + if __name__ == '__main__': absltest.main()