Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: l01ball #154

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Orthogonal projections
HyperPlaneBoxProj
IntersectionProj
L0BallProj
L01BallProj
L1BallProj
NuclearBallProj
SimplexProj
Expand Down Expand Up @@ -68,6 +69,7 @@ Convex
Intersection
L0
L0Ball
L01Ball
L1
L1Ball
L2
Expand Down
41 changes: 38 additions & 3 deletions pyproximal/projection/L0.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
from pyproximal.projection import SimplexProj


class L0BallProj():
r"""L0 ball projection.
r""":math:`L_0` ball projection.

Parameters
----------
Expand Down Expand Up @@ -32,4 +31,40 @@ def __call__(self, x):
xshape = x.shape
xf = x.copy().flatten()
xf[np.argsort(np.abs(xf))[:-self.radius]] = 0
return xf.reshape(xshape)
return xf.reshape(xshape)


class L01BallProj():
r""":math:`L_{0,1}` ball projection.

Parameters
----------
radius : :obj:`int`
Radius

Notes
-----
Given an :math:`L_{0,1}` ball defined as:

.. math::

L_{0,1}^{r} =
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1,
||\mathbf{x}_2||_1, ..., ||\mathbf{x}_1||_1] \ne 0) \leq r \}

its orthogonal projection is computed by finding the :math:`r` highest
largest entries of a vector obtained by applying the :math:`L_1` norm to each
column of a matrix :math:`\mathbf{x}` (in absolute value), keeping those
and zero-ing all the other entries.
Note that this is the proximal operator of the corresponding
indicator function :math:`\mathcal{I}_{L_{0,1}^{r}}`.

"""
def __init__(self, radius):
self.radius = int(radius)

def __call__(self, x):
xc = x.copy()
xf = np.linalg.norm(x, axis=0, ord=1)
xc[:, np.argsort(np.abs(xf))[:-self.radius]] = 0
return xc
2 changes: 1 addition & 1 deletion pyproximal/projection/L1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class L1BallProj():
r"""L1 ball projection.
r""":math:`L_1` ball projection.

Parameters
----------
Expand Down
3 changes: 2 additions & 1 deletion pyproximal/projection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
HyperPlaneBoxProj Projection onto an intersection beween a HyperPlane and a Box
SimplexProj Projection onto a Simplex
L0Proj Projection onto an L0 Ball
L01Proj Projection onto an L0,1 Ball
L1Proj Projection onto an L1 Ball
EuclideanBallProj Projection onto an Euclidean Ball
NuclearBallProj Projection onto a Nuclear Ball
Expand All @@ -29,5 +30,5 @@


__all__ = ['BoxProj', 'HyperPlaneBoxProj', 'SimplexProj', 'L0BallProj',
'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj',
'L01BallProj', 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj',
'IntersectionProj', 'AffineSetProj', 'HankelProj']
66 changes: 61 additions & 5 deletions pyproximal/proximal/L0.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from pyproximal.ProxOperator import _check_tau
from pyproximal.projection import L0BallProj
from pyproximal.projection import L0BallProj, L01BallProj
from pyproximal import ProxOperator
from pyproximal.proximal.L1 import _current_sigma

Expand Down Expand Up @@ -35,7 +35,7 @@ def _hardthreshold(x, thresh):


class L0(ProxOperator):
r"""L0 norm proximal operator.
r""":math:`L_0` norm proximal operator.

Proximal operator of the :math:`\ell_0` norm:
:math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.
Expand Down Expand Up @@ -92,7 +92,7 @@ def prox(self, x, tau):


class L0Ball(ProxOperator):
r"""L0 ball proximal operator.
r""":math:`L_0` ball proximal operator.

Proximal operator of the L0 ball: :math:`L0_{r} =
\{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`.
Expand All @@ -103,7 +103,6 @@ class L0Ball(ProxOperator):
Radius. This can be a constant number or a function that is called passing a
counter which keeps track of how many times the ``prox`` method has been
invoked before and returns a scalar ``radius`` to be used.
Radius

Notes
-----
Expand Down Expand Up @@ -136,4 +135,61 @@ def prox(self, x, tau):
radius = _current_sigma(self.radius, self.count)
self.ball.radius = radius
y = self.ball(x)
return y
return y


class L01Ball(ProxOperator):
r""":math:`L_{0,1}` ball proximal operator.

Proximal operator of the :math:`L_{0,1}` ball: :math:`L_{0,1}^{r} =
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ...,
||\mathbf{x}_1||_1] \ne 0) \leq r \}`

Parameters
----------
ndim : :obj:`int`
Number of dimensions :math:`N_{dim}`. Used to reshape the input array
in a matrix of size :math:`N_{dim} \times N'_{x}` where
:math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input
vector ``x`` should be created by stacking vectors from different
dimensions.
radius : :obj:`int` or :obj:`func`, optional
Radius. This can be a constant number or a function that is called passing a
counter which keeps track of how many times the ``prox`` method has been
invoked before and returns a scalar ``radius`` to be used.

Notes
-----
As the L0 ball is an indicator function, the proximal operator
corresponds to its orthogonal projection
(see :class:`pyproximal.projection.L01BallProj` for details.

"""
def __init__(self, ndim, radius):
super().__init__(None, False)
self.ndim = ndim
self.radius = radius
self.ball = L01BallProj(self.radius if not callable(radius) else radius(0))
self.count = 0

def __call__(self, x, tol=1e-4):
x = x.reshape(self.ndim, len(x) // self.ndim)
radius = _current_sigma(self.radius, self.count)
return np.linalg.norm(np.linalg.norm(x, ord=1, axis=0), ord=0) <= radius

def _increment_count(func):
"""Increment counter
"""
def wrapped(self, *args, **kwargs):
self.count += 1
return func(self, *args, **kwargs)
return wrapped

@_increment_count
@_check_tau
def prox(self, x, tau):
x = x.reshape(self.ndim, len(x) // self.ndim)
radius = _current_sigma(self.radius, self.count)
self.ball.radius = radius
y = self.ball(x)
return y.ravel()
3 changes: 2 additions & 1 deletion pyproximal/proximal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Nonlinear Nonlinear function
L0 L0 Norm
L0Ball L0 Ball
L01pBall L0,1 Ball
L1 L1 Norm
L1Ball L1 Ball
Euclidean Euclidean Norm
Expand Down Expand Up @@ -67,7 +68,7 @@
from .Hankel import *

__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L01Ball', 'L1', 'L1Ball', 'L2',
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
Expand Down
21 changes: 20 additions & 1 deletion pytests/test_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy.testing import assert_array_almost_equal
from pylops.basicoperators import Identity
from pyproximal.utils import moreau
from pyproximal.proximal import Box, EuclideanBall, L0Ball, L1Ball, \
from pyproximal.proximal import Box, EuclideanBall, L0Ball, L01Ball, L1Ball, \
NuclearBall, Simplex, AffineSet, Hankel

par1 = {'nx': 10, 'ny': 8, 'axis': 0, 'dtype': 'float32'} # even float32 dir0
Expand Down Expand Up @@ -65,6 +65,25 @@ def test_L0Ball(par):
assert moreau(l0, x, tau)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_L01Ball(par):
"""L01 Ball projection and proximal/dual proximal of related indicator
"""
np.random.seed(10)

l0 = L01Ball(3, 1)
x = np.random.normal(0., 1., (3, par['nx'])).astype(par['dtype']).ravel() + 1.

# evaluation
assert l0(x) == False
xp = l0.prox(x, 1.)
assert l0(xp) == True

# prox / dualprox
tau = 2.
assert moreau(l0, x, tau)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_L1Ball(par):
"""L1 Ball projection and proximal/dual proximal of related indicator
Expand Down
Loading