Skip to content

Commit

Permalink
Add projection_linf_ball.
Browse files Browse the repository at this point in the history
Also fix assertions that were incorrect in the linf norm case.

PiperOrigin-RevId: 688567383
  • Loading branch information
mblondel authored and OptaxDev committed Oct 22, 2024
1 parent 85378ad commit 259e20f
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 16 deletions.
15 changes: 15 additions & 0 deletions docs/api/projections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Available projections
projection_hypercube
projection_l1_ball
projection_l1_sphere
projection_l2_ball
projection_l2_sphere
projection_linf_ball
projection_non_negative
projection_simplex

Expand All @@ -54,6 +57,18 @@ Projection onto the L1 sphere
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_l1_sphere

Projection onto the L2 ball
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_l2_ball

Projection onto the L2 sphere
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_l2_sphere

Projection onto the L-infinity ball
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_linf_ball

Projection onto the non-negative orthant
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_non_negative
Expand Down
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Tree
tree_l1_norm
tree_l2_norm
tree_map_params
tree_max
tree_mul
tree_ones_like
tree_random_like
Expand Down Expand Up @@ -156,6 +157,10 @@ Tree map parameters
~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_map_params

Tree max
~~~~~~~~
.. autofunction:: tree_max

Tree multiply
~~~~~~~~~~~~~
.. autofunction:: tree_mul
Expand Down
3 changes: 3 additions & 0 deletions optax/projections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@
from optax.projections._projections import projection_hypercube
from optax.projections._projections import projection_l1_ball
from optax.projections._projections import projection_l1_sphere
from optax.projections._projections import projection_l2_ball
from optax.projections._projections import projection_l2_sphere
from optax.projections._projections import projection_linf_ball
from optax.projections._projections import projection_non_negative
from optax.projections._projections import projection_simplex
75 changes: 75 additions & 0 deletions optax/projections/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,78 @@ def projection_l1_ball(pytree: Any, scale: float = 1.0) -> Any:
lambda pytree: pytree,
lambda pytree: projection_l1_sphere(pytree, scale),
operand=pytree)


def projection_l2_sphere(pytree: Any, scale: float = 1.0) -> Any:
r"""Projection onto the l2 sphere.
This function solves the following constrained optimization problem,
where ``x`` is the input pytree.
.. math::
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
||y||_2 = \text{value}
Args:
pytree: pytree to project.
scale: radius of the sphere.
Returns:
projected pytree, with the same structure as ``pytree``.
.. versionadded:: 0.2.4
"""
factor = scale / otu.tree_l2_norm(pytree)
return otu.tree_scalar_mul(factor, pytree)


def projection_l2_ball(pytree: Any, scale: float = 1.0) -> Any:
r"""Projection onto the l2 ball.
This function solves the following constrained optimization problem,
where ``x`` is the input pytree.
.. math::
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
||y||_2 \le \text{scale}
Args:
pytree: pytree to project.
scale: radius of the ball.
Returns:
projected pytree, with the same structure as ``pytree``.
.. versionadded:: 0.2.4
"""
l2_norm = otu.tree_l2_norm(pytree)
factor = scale / l2_norm
return jax.lax.cond(l2_norm <= scale,
lambda pytree: pytree,
lambda pytree: otu.tree_scalar_mul(factor, pytree),
operand=pytree)


def projection_linf_ball(pytree: Any, scale: float = 1.0) -> Any:
r"""Projection onto the l-infinity ball.
This function solves the following constrained optimization problem,
where ``x`` is the input pytree.
.. math::
\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad
||y||_{\infty} \le \text{scale}
Args:
pytree: pytree to project.
scale: radius of the ball.
Returns:
projected pytree, with the same structure as ``pytree``.
"""
lower_tree = otu.tree_full_like(pytree, -scale)
upper_tree = otu.tree_full_like(pytree, scale)
return projection_box(pytree, lower=lower_tree, upper=upper_tree)
61 changes: 45 additions & 16 deletions optax/projections/_projections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,37 +174,66 @@ def test_projection_simplex_vmap(self, scale):
np.testing.assert_array_equal(True, 0 <= p)
np.testing.assert_array_equal(True, p <= scale)

@parameterized.product(data_key=['array_1d', 'array_2d', 'tree'],
scale=[1.0, 3.21])
def test_projection_l1_sphere(self, data_key, scale):
def _test_projection_sphere(self, data_key, scale, proj_fun, norm_fun):
"""Check correctness of the projection onto a sphere."""
x = self.data[data_key]
p = proj.projection_l1_sphere(x, scale)
# Check that the projection has the correct l1 norm.
np.testing.assert_almost_equal(otu.tree_l1_norm(p), scale, decimal=4)
p = proj_fun(x, scale)
np.testing.assert_almost_equal(norm_fun(p), scale, decimal=4)

@parameterized.product(
data_key=['array_1d', 'array_2d', 'tree'], scale=[1.0, 3.21]
)
def test_projection_l1_sphere(self, data_key, scale):
self._test_projection_sphere(
data_key, scale, proj.projection_l1_sphere, otu.tree_l1_norm
)

def _check_projection_ball(self, x, ball_proj, norm_fun):
@parameterized.product(
data_key=['array_1d', 'array_2d', 'tree'], scale=[1.0, 3.21]
)
def test_projection_l2_sphere(self, data_key, scale):
self._test_projection_sphere(
data_key, scale, proj.projection_l2_sphere, otu.tree_l2_norm
)

def _test_projection_ball(self, data_key, proj_fun, norm_fun):
"""Check correctness of the projection onto a ball."""
x = self.data[data_key]
eps = 1e-4

norm_value = norm_fun(x)

with self.subTest('Check when input is already in the ball'):
big_radius = norm_value * 2
p = ball_proj(x, big_radius)
np.testing.assert_array_almost_equal(x, p)
p = proj_fun(x, big_radius)
self.assertLessEqual(norm_fun(p), big_radius + eps)

with self.subTest('Check when input is on the boundary of the ball'):
p = ball_proj(x, norm_value)
np.testing.assert_array_almost_equal(x, p)
p = proj_fun(x, norm_value)
self.assertLessEqual(norm_fun(p), norm_value + eps)

with self.subTest('Check when input is outside the ball'):
small_radius = norm_value / 2
p = ball_proj(x, small_radius)
np.testing.assert_almost_equal(norm_fun(p), small_radius, decimal=4)
p = proj_fun(x, small_radius)
self.assertLessEqual(norm_fun(p), small_radius + eps)

@parameterized.parameters('array_1d', 'array_2d', 'tree')
@parameterized.product(data_key=['array_1d', 'array_2d', 'tree'])
def test_projection_l1_ball(self, data_key):
x = self.data[data_key]
self._check_projection_ball(x, proj.projection_l1_ball, otu.tree_l1_norm)
self._test_projection_ball(
data_key, proj.projection_l1_ball, otu.tree_l1_norm
)

@parameterized.product(data_key=['array_1d', 'array_2d', 'tree'])
def test_projection_l2_ball(self, data_key):
self._test_projection_ball(
data_key, proj.projection_l2_ball, otu.tree_l2_norm
)

@parameterized.product(data_key=['array_1d', 'array_2d', 'tree'])
def test_projection_linf_ball(self, data_key):
self._test_projection_ball(
data_key, proj.projection_linf_ball, otu.tree_linf_norm
)

if __name__ == '__main__':
absltest.main()
2 changes: 2 additions & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from optax.tree_utils._tree_math import tree_full_like
from optax.tree_utils._tree_math import tree_l1_norm
from optax.tree_utils._tree_math import tree_l2_norm
from optax.tree_utils._tree_math import tree_linf_norm
from optax.tree_utils._tree_math import tree_max
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
from optax.tree_utils._tree_math import tree_scalar_mul
Expand Down
27 changes: 27 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ def tree_sum(tree: Any) -> chex.Numeric:
return jax.tree.reduce(operator.add, sums, initializer=0)


def tree_max(tree: Any) -> chex.Numeric:
"""Compute the max of all the elements in a pytree.
Args:
tree: pytree.
Returns:
a scalar value.
"""
maxes = jax.tree.map(jnp.max, tree)
# initializer=-jnp.inf should work but pytype doesn't like it.
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))


def _square(leaf):
return jnp.square(leaf.real) + jnp.square(leaf.imag)

Expand Down Expand Up @@ -203,6 +217,19 @@ def tree_l1_norm(tree: Any) -> chex.Numeric:
return tree_sum(abs_tree)


def tree_linf_norm(tree: Any) -> chex.Numeric:
"""Compute the l-infinity norm of a pytree.
Args:
tree: pytree.
Returns:
a scalar value.
"""
abs_tree = jax.tree.map(jnp.abs, tree)
return tree_max(abs_tree)


def tree_zeros_like(
tree: Any,
dtype: Optional[jax.typing.DTypeLike] = None,
Expand Down
19 changes: 19 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ def test_tree_sum(self):
got = tu.tree_sum(self.tree_a)
np.testing.assert_allclose(expected, got)

def test_tree_max(self):
expected = jnp.max(self.array_a)
got = tu.tree_max(self.array_a)
np.testing.assert_allclose(expected, got)

expected = max(jnp.max(self.tree_a[0]), jnp.max(self.tree_a[1]))
got = tu.tree_max(self.tree_a)
np.testing.assert_allclose(expected, got)

def test_tree_l2_norm(self):
expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real)
got = tu.tree_l2_norm(self.array_a)
Expand All @@ -153,6 +162,16 @@ def test_tree_l1_norm(self, key):
got = tu.tree_l1_norm(tree)
np.testing.assert_allclose(expected, got, atol=1e-4)

@parameterized.parameters(
'tree_a', 'tree_a_dict', 'tree_b', 'array_a', 'array_b', 'tree_b_dict'
)
def test_tree_linf_norm(self, key):
tree = self.data[key]
values, _ = flatten_util.ravel_pytree(tree)
expected = jnp.max(jnp.abs(values))
got = tu.tree_linf_norm(tree)
np.testing.assert_allclose(expected, got, atol=1e-4)

def test_tree_zeros_like(self):
expected = jnp.zeros_like(self.array_a)
got = tu.tree_zeros_like(self.array_a)
Expand Down

0 comments on commit 259e20f

Please sign in to comment.