diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 2bd1400..7edf0cb 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -10,8 +10,6 @@ build:
os: ubuntu-20.04
tools:
python: "3.9"
- apt_packages:
- - libopenblas-dev
# Build documentation in the docs/ directory with Sphinx
sphinx:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ac6a9ab..327925b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,16 @@
+Changelog
+=========
+
+# 0.9.0
+
+* Added :py:func:`pyproximal.optimization.palm.iPALM` solver
+* Added :py:func:`pyproximal.optimization.palm._backtracking` function to be used when `gammaf=None` and/or `gammag=None`
+* Added :py:func:`pyproximal.utils.gradtest.gradtest_proximal` and :py:func:`pyproximal.utils.gradtest.gradtest_bilinear` functions
+* Added `tol` to :py:func:`pyproximal.optimization.primal.ProximalPoint` and
+ :py:func:`pyproximal.optimization.primal.ProximalGradient` solvers
+* Modified :py:meth:`pyproximal.ProxOperator.precomposition` to allow `b` to also be a vector
+
+
# 0.8.0
* Added ``pyproximal.projection.L01BallProj`` and ``pyproximal.proximal.L01Ball`` operators
diff --git a/README.md b/README.md
index 46f0872..3149ed3 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,12 @@
![PyProximal](https://github.com/PyLops/pyproximal/blob/dev/docs/source/_static/pyproximal_b.png)
[![PyPI version](https://badge.fury.io/py/pyproximal.svg)](https://badge.fury.io/py/pyproximal)
-[![Build Status](https://travis-ci.com/PyLops/pyproximal.svg?branch=main)](https://travis-ci.com/PyLops/pyproximal)
-[![AzureDevOps Status](https://dev.azure.com/matteoravasi/PyLops/_apis/build/status%2FPyLops.pyproximal?branchName=refs%2Fpull%2F129%2Fmerge)](https://dev.azure.com/matteoravasi/PyLops/_build/latest?definitionId=10&branchName=refs%2Fpull%2F129%2Fmerge)
+[![AzureDevOps Status](https://dev.azure.com/matteoravasi/PyLops/_apis/build/status%2FPyLops.pyproximal?branchName=refs%2Fpull%2F180%2Fmerge)](https://dev.azure.com/matteoravasi/PyLops/_build/latest?definitionId=10&branchName=refs%2Fpull%2F180%2Fmerge)
![GithubAction Status](https://github.com/PyLops/pyproximal/workflows/PyProx/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/pyproximal/badge/?version=latest)](https://pyproximal.readthedocs.io/en/latest/?badge=latest)
[![OS-support](https://img.shields.io/badge/OS-linux,osx-850A8B.svg)](https://github.com/PyLops/pyproximal)
[![Slack Status](https://img.shields.io/badge/chat-slack-green.svg)](https://pylops.slack.com)
-
+[![DOI](https://joss.theoj.org/papers/10.21105/joss.06326/status.svg)](https://doi.org/10.21105/joss.06326)
:vertical_traffic_light: :vertical_traffic_light: This library is under early development.
@@ -184,9 +183,17 @@ make docupdate
Note that if a new example or tutorial is created (and if any change is made to a previously available example or tutorial)
you are required to rebuild the entire documentation before your changes will be visible.
+## Citing
+When using PyProximal in scientific publications, please cite the following paper:
+
+- Ravasi M., Örnhag M. V., Luiken N., Leblanc O. and Uruñuela E., 2024, *PyProximal - scalable convex optimization in Python*,
+ Journal of Open Source Software, 9(95), 6326. doi: 10.21105/joss.06326 [(link)](https://joss.theoj.org/papers/10.21105/joss.06326)
+
+
## Contributors
* Matteo Ravasi, mrava87
* Nick Luiken, NickLuiken
* Eneko Uruñuela, eurunuela
* Marcus Valtonen Örnhag, marcusvaltonen
+* Olivier Leblanc, olivierleblanc
diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
index 085037a..4000054 100755
--- a/docs/source/api/index.rst
+++ b/docs/source/api/index.rst
@@ -130,6 +130,18 @@ Other operators
BilinearOperator
LowRankFactorizedMatrix
+Utility functions
+-----------------
+
+.. currentmodule:: pyproximal.utils.gradtest
+
+.. autosummary::
+ :toctree: generated/
+
+ gradtest_proximal
+ gradtest_bilinear
+
+
Solvers
-------
@@ -158,6 +170,7 @@ Primal
:toctree: generated/
PALM
+ iPALM
.. currentmodule:: pyproximal.optimization.pnp
diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst
index b64508a..a0fc72b 100644
--- a/docs/source/changelog.rst
+++ b/docs/source/changelog.rst
@@ -3,6 +3,18 @@
Changelog
=========
+Version 0.9.0
+--------------
+*Released on: 16/08/2024*
+
+* Added :py:func:`pyproximal.optimization.palm.iPALM` solver
+* Added :py:func:`pyproximal.optimization.palm._backtracking` function to be used when `gammaf=None` and/or `gammag=None`
+* Added :py:func:`pyproximal.utils.gradtest.gradtest_proximal` and :py:func:`pyproximal.utils.gradtest.gradtest_bilinear` functions
+* Added `tol` to :py:func:`pyproximal.optimization.primal.ProximalPoint` and
+ :py:func:`pyproximal.optimization.primal.ProximalGradient` solvers
+* Modified :py:meth:`pyproximal.ProxOperator.precomposition` to allow `b` to also be a vector
+
+
Version 0.8.0
--------------
*Released on: 11/03/2024*
diff --git a/docs/source/credits.rst b/docs/source/credits.rst
index 5233020..750a2c1 100644
--- a/docs/source/credits.rst
+++ b/docs/source/credits.rst
@@ -6,4 +6,5 @@ Contributors
* `Matteo Ravasi `_, mrava87
* `Nick Luiken `_, NickLuiken
* `Eneko Uruñuela `_, eurunuela
-* `Marcus Valtonen Örnhag `_, marcusvaltonen
\ No newline at end of file
+* `Marcus Valtonen Örnhag `_, marcusvaltonen
+* `Olivier Leblanc `_, olivierleblanc
\ No newline at end of file
diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml
new file mode 100644
index 0000000..6e2f743
--- /dev/null
+++ b/environment-dev-arm.yml
@@ -0,0 +1,27 @@
+name: pyproximal
+channels:
+ - defaults
+ - conda-forge
+ - numba
+dependencies:
+ - python>=3.8.12
+ - numpy>=1.15.0, <2.0.0
+ - scipy>=1.8.0
+ - pylops>=2.0.0
+ - scikit-image
+ - matplotlib
+ - ipython
+ - pytest
+ - Sphinx
+ - numpydoc
+ - numba
+ - icc_rt
+ - pip:
+ - bm3d
+ - pytest-runner
+ - setuptools_scm
+ - pydata-sphinx-theme
+ - sphinx-gallery
+ - nbsphinx
+ - image
+ - sphinxemoji
\ No newline at end of file
diff --git a/environment-dev.yml b/environment-dev.yml
index cbc794f..3855772 100644
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -5,7 +5,7 @@ channels:
- numba
dependencies:
- python>=3.8.12
- - numpy>=1.15.0
+ - numpy>=1.15.0, <2.0.0
- scipy>=1.8.0
- pylops>=2.0.0
- scikit-image
@@ -17,7 +17,8 @@ dependencies:
- numba
- icc_rt
- pip:
- - bm3d
+ - bm4d<4.2.4 # temporary fix as GLIBC_2.32 not found by readthedocs
+ - bm3d<4.0.2 # temporary fix as GLIBC_2.32 not found by readthedocs
- pytest-runner
- setuptools_scm
- pydata-sphinx-theme
diff --git a/environment.yml b/environment.yml
index b97842e..4ca4350 100644
--- a/environment.yml
+++ b/environment.yml
@@ -3,6 +3,6 @@ channels:
- defaults
dependencies:
- python>=3.8.12
- - numpy>=1.15.0
+ - numpy>=1.15.0, <2.0.0
- scipy>=1.8.0
- pylops>=2.0.0
diff --git a/pyproject.toml b/pyproject.toml
index 6dcf6a7..0b61eea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
dependencies = [
- "numpy >= 1.15.0",
+ "numpy >= 1.15.0, <2.0.0",
"scipy >= 1.8.0",
"pylops >= 2.0.0",
]
diff --git a/pyproximal/ProxOperator.py b/pyproximal/ProxOperator.py
index a0d42e6..eb3c514 100644
--- a/pyproximal/ProxOperator.py
+++ b/pyproximal/ProxOperator.py
@@ -73,9 +73,9 @@ def _proxdual_moreau(self, x, tau, **kwargs):
def prox(self, x, tau, **kwargs):
"""Proximal operator applied to a vector
- The proximal operator can always be computed given its dual
+ The proximal operator can always be computed given its dual
proximal operator using the Moreau decomposition as defined in
- :func:`pyprox.moreau`. For this reason we can easily create a common
+ :func:`pyproximal.moreau`. For this reason we can easily create a common
method for all proximal operators that can be evaluated provided the
dual proximal is implemented.
@@ -83,7 +83,6 @@ def prox(self, x, tau, **kwargs):
be done by simply implementing ``prox`` for a specific proximal
operator, which will overwrite the general method.
-
Parameters
----------
x : :obj:`np.ndarray`
@@ -100,9 +99,9 @@ def proxdual(self, x, tau, **kwargs):
The dual of a proximal operator can always be computed given its
proximal operator using the Moreau decomposition as defined in
- :func:`pyprox.moreau`. For this reason we can easily create a common
+ :func:`pyproximal.moreau`. For this reason we can easily create a common
method for all dual proximal operators that can be evaluated provided
- he proximal is implemented.
+ the proximal is implemented.
However, since the dual of a proximal operator of a function is
equivalent to the proximal operator of the conjugate function, smarter
@@ -163,7 +162,7 @@ def affine_addition(self, v):
if isinstance(v, np.ndarray):
return _SumOperator(self, v)
else:
- return NotImplemented
+ raise NotImplementedError('v must be of type numpy.ndarray')
def postcomposition(self, sigma):
r"""Postcomposition
@@ -191,7 +190,7 @@ def postcomposition(self, sigma):
if isinstance(sigma, float):
return _PostcompositionOperator(self, sigma)
else:
- return NotImplemented
+ raise NotImplementedError('sigma must be of type float')
def precomposition(self, a, b):
r"""Precomposition
@@ -203,8 +202,8 @@ def precomposition(self, a, b):
----------
a : :obj:`float`
Multiplicative scalar
- b : :obj:`float`
- Additive Scalar
+ b : :obj:`float` or :obj:`np.ndarray`
+ Additive scalar (or vector)
Notes
-----
@@ -217,10 +216,12 @@ def precomposition(self, a, b):
prox_{a^2 \tau f} (a \mathbf{x} + b) - b)
"""
- if isinstance(a, float) and isinstance(b, float):
+ if isinstance(a, float) and isinstance(b, (float, np.ndarray)):
return _PrecompositionOperator(self, a, b)
else:
- return NotImplemented
+ raise NotImplementedError('a must be of type float and b '
+ 'must be of type float or '
+ 'numpy.ndarray')
def chain(self, g):
r"""Chain
@@ -347,7 +348,7 @@ def __init__(self, f, a, b):
# raise ValueError('First input must be a ProxOperator')
if not isinstance(a, float):
raise ValueError('Second input must be a float')
- if not isinstance(b, float):
+ if not isinstance(b, (float, np.ndarray)):
- raise ValueError('Second input must be a float')
+ raise ValueError('Third input must be a float or numpy.ndarray')
self.f, self.a, self.b = f, a, b
super().__init__(None, True if f.grad else False)
diff --git a/pyproximal/optimization/__init__.py b/pyproximal/optimization/__init__.py
index c25ac22..6dd04b7 100644
--- a/pyproximal/optimization/__init__.py
+++ b/pyproximal/optimization/__init__.py
@@ -34,6 +34,7 @@
SR3 Sparse Relaxed Regularized algorithm
PALM Proximal Alternating Linearized Minimization
+ iPALM Inertial Proximal Alternating Linearized Minimization
Finally this subpackage contains also a solver for image segmentation based
on a special use of the Primal-Dual algorithm:
diff --git a/pyproximal/optimization/palm.py b/pyproximal/optimization/palm.py
index 27f93b5..6dc2161 100644
--- a/pyproximal/optimization/palm.py
+++ b/pyproximal/optimization/palm.py
@@ -1,8 +1,41 @@
import time
+import numpy as np
-def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
- niter=10, callback=None, show=False):
+def _backtracking(x, tau, H, proxf, ix, beta=0.5, niterback=10):
+ r"""Backtracking
+
+ Line-search algorithm for finding step sizes in PALM-type algorithms when
+ the Lipschitz constant of the operator is unknown (or expensive to
+ estimate).
+
+ """
+ def ftilde(x, y, f, g, tau, ix):
+ xy = x - y[ix]
+ return f(*y) + np.dot(g, xy) + \
+ (1. / (2. * tau)) * np.linalg.norm(xy) ** 2
+
+ iiterback = 0
+ if ix == 0:
+ grad = H.gradx(x[ix])
+ else:
+ grad = H.grady(x[ix])
+ z = [x_.copy() for x_ in x]
+ while iiterback < niterback:
+ z[ix] = x[ix] - tau * grad
+ if proxf is not None:
+ z[ix] = proxf.prox(z[ix], tau)
+ ft = ftilde(z[ix], x, H, grad, tau, ix)
+ f = H(*z)
+ if f <= ft or tau < 1e-12:
+ break
+ tau *= beta
+ iiterback += 1
+ return z[ix], tau
+
+
+def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1., beta=0.5,
+ niter=10, niterback=100, callback=None, show=False):
r"""Proximal Alternating Linearized Minimization
Solves the following minimization problem using the Proximal Alternating
@@ -30,11 +63,17 @@ def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
y0 : :obj:`numpy.ndarray`
Initial y vector
gammaf : :obj:`float`, optional
- Positive scalar weight for ``f`` function update
+ Positive scalar weight for ``f`` function update.
+ If ``None``, use backtracking
gammag : :obj:`float`, optional
- Positive scalar weight for ``g`` function update
+ Positive scalar weight for ``g`` function update.
+ If ``None``, use backtracking
+ beta : :obj:`float`, optional
+ Backtracking parameter (must be between 0 and 1)
niter : :obj:`int`, optional
Number of iterations of iterative scheme
+ niterback : :obj:`int`, optional
+ Max number of iterations of backtracking
callback : :obj:`callable`, optional
- Function with signature (``callback(x)``) to call after each iteration
+ Function with signature (``callback(x, y)``) to call after each iteration
where ``x`` and ``y`` are the current model vectors
@@ -61,7 +100,9 @@ def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
Here :math:`c_k=\gamma_f L_x` and :math:`d_k=\gamma_g L_y`, where
:math:`L_x` and :math:`L_y` are the Lipschitz constant of :math:`\nabla_x H`
- and :math:`\nabla_y H`, respectively.
+ and :math:`\nabla_y H`, respectively. When such constants cannot be easily
+ computed, a back-tracking algorithm can instead be employed to find suitable
+ :math:`c_k` and :math:`d_k` parameters.
.. [1] Bolte, J., Sabach, S., and Teboulle, M. "Proximal alternating
linearized minimization for nonconvex and nonsmooth problems",
@@ -75,22 +116,206 @@ def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
'Bilinear operator: %s\n'
'Proximal operator (f): %s\n'
'Proximal operator (g): %s\n'
- 'gammaf = %10e\tgammaf = %10e\tniter = %d\n' %
- (type(H), type(proxf), type(proxg), gammaf, gammag, niter))
+ 'gammaf = %s\tgammag = %s\tniter = %d\n' %
+ (type(H), type(proxf), type(proxg), str(gammaf), str(gammag), niter))
head = ' Itn x[0] y[0] f g H ck dk'
print(head)
+ backtrackingf, backtrackingg = False, False
+ if gammaf is None:
+ backtrackingf = True
+ tauf = 1.
+ ck = 0.
+ if gammag is None:
+ backtrackingg = True
+ taug = 1.
+ dk = 0.
+
x, y = x0.copy(), y0.copy()
for iiter in range(niter):
- ck = gammaf * H.ly(y)
- x = x - (1 / ck) * H.gradx(x.ravel())
- if proxf is not None:
- x = proxf.prox(x, ck)
+ # x step
+ if not backtrackingf:
+ ck = gammaf * H.ly(y)
+ x = x - (1. / ck) * H.gradx(x)
+ if proxf is not None:
+ x = proxf.prox(x, 1. / ck)
+ else:
+ x, tauf = _backtracking([x, y], tauf, H,
+ proxf, 0, beta=beta,
+ niterback=niterback)
+ # update x parameter in H function
+ H.updatex(x.copy())
+
+ # y step
+ if not backtrackingg:
+ dk = gammag * H.lx(x)
+ y = y - (1. / dk) * H.grady(y)
+ if proxg is not None:
+ y = proxg.prox(y, 1. / dk)
+ else:
+ y, taug = _backtracking([x, y], taug, H,
+ proxg, 1, beta=beta,
+ niterback=niterback)
+ # update y parameter in H function
+ H.updatey(y.copy())
+
+ # run callback
+ if callback is not None:
+ callback(x, y)
+
+ if show:
+ pf = proxf(x) if proxf is not None else 0.
+ pg = proxg(y) if proxg is not None else 0.
+ if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
+ msg = '%6g %5.5e %5.2e %5.2e %5.2e %5.2e %5.2e %5.2e' % \
+ (iiter + 1, x[0], y[0], pf, pg, H(x, y), ck, dk)
+ print(msg)
+ if show:
+ print('\nTotal time (s) = %.2f' % (time.time() - tstart))
+ print('---------------------------------------------------------\n')
+ return x, y
+
+
+def iPALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
+ a=[1., 1.], b=None, beta=0.5, niter=10, niterback=100,
+ callback=None, show=False):
+ r"""Inertial Proximal Alternating Linearized Minimization
+
+ Solves the following minimization problem using the Inertial Proximal
+ Alternating Linearized Minimization (iPALM) algorithm:
+
+ .. math::
+
+ \mathbf{x}, \mathbf{y} = \argmin_{\mathbf{x}, \mathbf{y}}
+ f(\mathbf{x}) + g(\mathbf{y}) + H(\mathbf{x}, \mathbf{y})
+
+ where :math:`f(\mathbf{x})` and :math:`g(\mathbf{y})` are any pair of
+ convex functions that have known proximal operators, and
+ :math:`H(\mathbf{x}, \mathbf{y})` is a smooth function.
+
+ Parameters
+ ----------
+ H : :obj:`pyproximal.utils.bilinear.Bilinear`
+ Bilinear function
+ proxf : :obj:`pyproximal.ProxOperator`
+ Proximal operator of f function
+ proxg : :obj:`pyproximal.ProxOperator`
+ Proximal operator of g function
+ x0 : :obj:`numpy.ndarray`
+ Initial x vector
+ y0 : :obj:`numpy.ndarray`
+ Initial y vector
+ gammaf : :obj:`float`, optional
+ Positive scalar weight for ``f`` function update.
+ If ``None``, use backtracking
+ gammag : :obj:`float`, optional
+ Positive scalar weight for ``g`` function update.
+ If ``None``, use backtracking
+ a : :obj:`list`, optional
+ Inertial parameters (:math:`a \in [0, 1]`)
+ beta : :obj:`float`, optional
+ Backtracking parameter (must be between 0 and 1)
+ niter : :obj:`int`, optional
+ Number of iterations of iterative scheme
+ niterback : :obj:`int`, optional
+ Max number of iterations of backtracking
+ callback : :obj:`callable`, optional
+ Function with signature (``callback(x, y)``) to call after each iteration
+ where ``x`` and ``y`` are the current model vectors
+ show : :obj:`bool`, optional
+ Display iterations log
+
+ Returns
+ -------
+ x : :obj:`numpy.ndarray`
+ Inverted x vector
+ y : :obj:`numpy.ndarray`
+ Inverted y vector
+
+ Notes
+ -----
+ iPALM [1]_ can be expressed by the following recursion:
+
+ .. math::
+
+ \mathbf{x}_z^k = \mathbf{x}^k + \alpha_x (\mathbf{x}^k - \mathbf{x}^{k-1})\\
+ \mathbf{x}^{k+1} = \prox_{\frac{1}{c_k} f}(\mathbf{x}_z^k -
+ \frac{1}{c_k}\nabla_x H(\mathbf{x}_z^k, \mathbf{y}^{k}))\\
+ \mathbf{y}_z^k = \mathbf{y}^k + \alpha_y (\mathbf{y}^k - \mathbf{y}^{k-1})\\
+ \mathbf{y}^{k+1} = \prox_{\frac{1}{d_k} g}(\mathbf{y}_z^k -
+ \frac{1}{d_k}\nabla_y H(\mathbf{x}^{k+1}, \mathbf{y}_z^k))
+
+ Here :math:`c_k=\gamma_f L_x` and :math:`d_k=\gamma_g L_y`, where
+ :math:`L_x` and :math:`L_y` are the Lipschitz constant of :math:`\nabla_x H`
+ and :math:`\nabla_y H`, respectively. When such constants cannot be easily
+ computed, a back-tracking algorithm can instead be employed to find suitable
+ :math:`c_k` and :math:`d_k` parameters.
+
+ Finally, note that we have implemented the version of iPALM where :math:`\beta_x=\alpha_x`
+ and :math:`\beta_y=\alpha_y`.
+
+ .. [1] Pock, T., and Sabach, S. "Inertial Proximal
+ Alternating Linearized Minimization (iPALM) for Nonconvex and
+ Nonsmooth Problems", SIAM Journal on Imaging Sciences, vol. 9. 2016.
+
+ """
+ if show:
+ tstart = time.time()
+ print('iPALM algorithm\n'
+ '---------------------------------------------------------\n'
+ 'Bilinear operator: %s\n'
+ 'Proximal operator (f): %s\n'
+ 'Proximal operator (g): %s\n'
+ 'gammaf = %s\tgammag = %s\n'
+ 'a = %s\tniter = %d\n' %
+ (type(H), type(proxf), type(proxg), str(gammaf), str(gammag), str(a), niter))
+ head = ' Itn x[0] y[0] f g H ck dk'
+ print(head)
+
+ backtrackingf, backtrackingg = False, False
+ if gammaf is None:
+ backtrackingf = True
+ tauf = 1.
+ ck = 0.
+ if gammag is None:
+ backtrackingg = True
+ taug = 1.
+ dk = 0.
+
+ x, y = x0.copy(), y0.copy()
+ xold, yold = x0.copy(), y0.copy()
+ for iiter in range(niter):
+ # x step
+ z = x + a[0] * (x - xold)
+ if not backtrackingf:
+ ck = gammaf * H.ly(y)
+ xold = x.copy()
+ x = z - (1. / ck) * H.gradx(z)
+ if proxf is not None:
+ x = proxf.prox(x, 1. / ck)
+ else:
+ xold = x.copy()
+ x, tauf = _backtracking([z, y], tauf, H,
+ proxf, 0, beta=beta,
+ niterback=niterback)
+ # update x parameter in H function
H.updatex(x.copy())
- dk = gammag * H.lx(x)
- y = y - (1 / dk) * H.grady(y.ravel())
- if proxg is not None:
- y = proxg.prox(y, dk)
+
+ # y step
+ z = y + a[1] * (y - yold)
+ if not backtrackingg:
+ dk = gammag * H.lx(x)
+ yold = y.copy()
+ y = z - (1. / dk) * H.grady(z)
+ if proxg is not None:
+ y = proxg.prox(y, 1. / dk)
+ else:
+ yold = y.copy()
+ y, taug = _backtracking([x, z], taug, H,
+ proxg, 1, beta=beta,
+ niterback=niterback)
+ # update y parameter in H function
H.updatey(y.copy())
# run callback
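
The backtracking path added above is triggered by passing `gammaf=None`/`gammag=None`; a minimal sketch on a small low-rank factorization problem (sizes, `L1` regularizers, and inertial parameters are illustrative):

```python
import numpy as np
from pyproximal.proximal import L1
from pyproximal.utils.bilinear import LowRankFactorizedMatrix
from pyproximal.optimization.palm import iPALM

np.random.seed(0)
n, m, k = 20, 15, 3
d = (np.random.normal(0., 1., (n, k)) @
     np.random.normal(0., 1., (k, m))).ravel()
Xin = np.random.normal(0., 1., (n, k))
Yin = np.random.normal(0., 1., (k, m))
Hop = LowRankFactorizedMatrix(Xin.copy(), Yin.copy(), d)

# gammaf=None / gammag=None trigger the _backtracking line search
xest, yest = iPALM(Hop, L1(sigma=1e-4), L1(sigma=1e-4),
                   Xin.ravel(), Yin.ravel(),
                   gammaf=None, gammag=None,
                   a=[0.8, 0.8], niter=50)
```
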
diff --git a/pyproximal/optimization/primal.py b/pyproximal/optimization/primal.py
index 8dccdab..1bee7e0 100644
--- a/pyproximal/optimization/primal.py
+++ b/pyproximal/optimization/primal.py
@@ -33,7 +33,8 @@ def ftilde(x, y, f, tau):
return z, tau
-def ProximalPoint(prox, x0, tau, niter=10, callback=None, show=False):
+def ProximalPoint(prox, x0, tau, niter=10,
+ tol=None, callback=None, show=False):
r"""Proximal point algorithm
Solves the following minimization problem using Proximal point algorithm:
@@ -55,6 +56,9 @@ def ProximalPoint(prox, x0, tau, niter=10, callback=None, show=False):
Positive scalar weight
niter : :obj:`int`, optional
Number of iterations of iterative scheme
+ tol : :obj:`float`, optional
+ Tolerance on change of objective function (used as stopping criterion). If
+ ``tol=None``, run until ``niter`` is reached
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
@@ -80,10 +84,17 @@ def ProximalPoint(prox, x0, tau, niter=10, callback=None, show=False):
print('Proximal point algorithm\n'
'---------------------------------------------------------\n'
'Proximal operator: %s\n'
- 'tau = %10e\tniter = %d\n' % (type(prox), tau, niter))
- head = ' Itn x[0] f'
+ 'tau = %10e\tniter = %d\ttol = %s\n' % (type(prox), tau, niter, str(tol)))
+ head = ' Itn x[0] f'
print(head)
+
+ # initialize model
x = x0.copy()
+ pf = np.inf
+ tolbreak = False
+
+ # iterate
for iiter in range(niter):
x = prox.prox(x, tau)
@@ -91,11 +102,27 @@ def ProximalPoint(prox, x0, tau, niter=10, callback=None, show=False):
if callback is not None:
callback(x)
+ # tolerance check: stop iterating when the relative change
+ # of the objective falls below tolerance
+ if tol is not None:
+ pfold = pf
+ pf = prox(x)
+ if np.abs(1.0 - pf / pfold) < tol:
+ tolbreak = True
+
+ # show iteration logger
if show:
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
+ if tol is None:
+ pf = prox(x)
msg = '%6g %12.5e %10.3e' % \
- (iiter + 1, x[0], prox(x))
+ (iiter + 1, x[0], pf)
print(msg)
+
+ # break if tolerance condition is met
+ if tolbreak:
+ break
+
if show:
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
print('---------------------------------------------------------\n')
@@ -103,9 +130,10 @@ def ProximalPoint(prox, x0, tau, niter=10, callback=None, show=False):
def ProximalGradient(proxf, proxg, x0, epsg=1.,
- tau=None, beta=0.5, eta=1.,
+ tau=None, backtracking=False,
+ beta=0.5, eta=1.,
niter=10, niterback=100,
- acceleration=None,
+ acceleration=None, tol=None,
callback=None, show=False):
r"""Proximal gradient (optionally accelerated)
@@ -137,6 +165,10 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
backtracking is used to adaptively estimate the best tau at each
iteration. Finally, note that :math:`\tau` can be chosen to be a vector
when dealing with problems with multiple right-hand-sides
+ backtracking : :obj:`bool`, optional
+ Force backtracking, even if ``tau`` is not equal to ``None``. In this case
+ the chosen ``tau`` will be used as the initial guess in the first
+ step of backtracking
beta : :obj:`float`, optional
Backtracking parameter (must be between 0 and 1)
eta : :obj:`float`, optional
@@ -147,6 +179,9 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
Max number of iterations of backtracking
acceleration : :obj:`str`, optional
Acceleration (``None``, ``vandenberghe`` or ``fista``)
+ tol : :obj:`float`, optional
+ Tolerance on change of objective function (used as stopping criterion). If
+ ``tol=None``, run until ``niter`` is reached
callback : :obj:`callable`, optional
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
@@ -190,7 +225,7 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
- ``acceleration=None``: :math:`\omega^k = 0`;
- ``acceleration=vandenberghe`` [1]_: :math:`\omega^k = k / (k + 3)`;
- - ``acceleration=fista``: :math:`\omega^k = (t_{k-1}-1)/t_k` for where
+ - ``acceleration=fista``: :math:`\omega^k = (t_{k-1}-1)/t_k` where
:math:`t_k = (1 + \sqrt{1+4t_{k-1}^{2}}) / 2` [2]_
.. [1] Vandenberghe, L., "Fast proximal gradient methods", 2010.
@@ -215,16 +250,16 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
'---------------------------------------------------------\n'
'Proximal operator (f): %s\n'
'Proximal operator (g): %s\n'
- 'tau = %s\tbeta=%10e\n'
- 'epsg = %s\tniter = %d\n'
+ 'tau = %s\tbacktrack = %s\tbeta = %10e\n'
+ 'epsg = %s\tniter = %d\ttol = %s\n'
''
'niterback = %d\tacceleration = %s\n' % (type(proxf), type(proxg),
- 'Adaptive' if tau is None else str(tau), beta,
- epsg_print, niter, niterback, acceleration))
+ str(tau), backtracking, beta,
+ epsg_print, niter, str(tol),
+ niterback, acceleration))
head = ' Itn x[0] f g J=f+eps*g tau'
print(head)
- backtracking = False
if tau is None:
backtracking = True
tau = 1.
@@ -233,6 +268,8 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
t = 1.
x = x0.copy()
y = x.copy()
+ pfg = np.inf
+ tolbreak = False
# iterate
for iiter in range(niter):
@@ -269,15 +306,32 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
if callback is not None:
callback(x)
+ # tolerance check: stop iterating when the relative change
+ # of the objective falls below tolerance
+ if tol is not None:
+ pfgold = pfg
+ pf, pg = proxf(x), proxg(x)
+ pfg = pf + np.sum(epsg[iiter] * pg)
+ if np.abs(1.0 - pfg / pfgold) < tol:
+ tolbreak = True
+
+ # show iteration logger
if show:
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
- pf, pg = proxf(x), proxg(x)
+ if tol is None:
+ pf, pg = proxf(x), proxg(x)
+ pfg = pf + np.sum(epsg[iiter] * pg)
msg = '%6g %12.5e %10.3e %10.3e %10.3e %10.3e' % \
(iiter + 1, np.real(to_numpy(x[0])) if x.ndim == 1 else np.real(to_numpy(x[0, 0])),
pf, pg,
- pf + np.sum(epsg[iiter] * pg),
+ pfg,
tau)
print(msg)
+
+ # break if tolerance condition is met
+ if tolbreak:
+ break
+
if show:
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
print('---------------------------------------------------------\n')
@@ -286,7 +340,7 @@ def ProximalGradient(proxf, proxg, x0, epsg=1.,
def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
epsg=1., niter=10, niterback=100,
- acceleration='vandenberghe',
+ acceleration='vandenberghe', tol=None,
callback=None, show=False):
r"""Accelerated Proximal gradient
@@ -301,7 +355,7 @@ def AcceleratedProximalGradient(proxf, proxg, x0, tau=None, beta=0.5,
'version v1.0.0 and AcceleratedProximalGradient will be removed.', FutureWarning)
return ProximalGradient(proxf, proxg, x0, tau=tau, beta=beta,
epsg=epsg, niter=niter, niterback=niterback,
- acceleration=acceleration,
+ acceleration=acceleration, tol=tol,
callback=callback, show=show)
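
The new `tol`-based early stopping can be used as follows; a minimal sketch (the L2+L1 problem, weights, and tolerance are illustrative):

```python
import numpy as np
from pylops.basicoperators import MatrixMult
from pyproximal.proximal import L1, L2
from pyproximal.optimization.primal import ProximalGradient

np.random.seed(0)
A = np.random.normal(0., 1., (30, 10))
xtrue = np.zeros(10)
xtrue[[2, 7]] = 1., -1.
y = A @ xtrue

l2 = L2(Op=MatrixMult(A), b=y)  # smooth data term
l1 = L1(sigma=1e-1)             # sparse regularizer

# iterations stop early once the relative change of f + eps*g
# falls below tol (tau=None also enables backtracking)
xest = ProximalGradient(l2, l1, x0=np.zeros(10), tau=None,
                        niter=1000, tol=1e-10,
                        acceleration='fista')
```
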
diff --git a/pyproximal/utils/bilinear.py b/pyproximal/utils/bilinear.py
index 278fcb3..e36371f 100644
--- a/pyproximal/utils/bilinear.py
+++ b/pyproximal/utils/bilinear.py
@@ -115,12 +115,12 @@ def __init__(self, X, Y, d, Op=None):
self.y = Y
self.d = d
self.Op = Op
- self.shapex = (self.n * self.m, self.n * self.k)
- self.shapey = (self.n * self.m, self.m * self.k)
+ self.sizex = self.n * self.k
+ self.sizey = self.m * self.k
def __call__(self, x, y=None):
if y is None:
- x, y = x[:self.n * self.k], x[self.n * self.k:]
+ x, y = x[:self.n * self.k], x[self.n * self.k:]
xold = self.x.copy()
self.updatex(x)
res = self.d - self._matvecy(y)
@@ -147,25 +147,23 @@ def matvec(self, x):
'cannot distinguish automatically'
'between _matvecx and _matvecy. '
'Explicitely call either of those two methods.')
- if x.size == self.shapex[1]:
+ if x.size == self.sizex:
y = self._matvecx(x)
else:
y = self._matvecy(x)
return y
def lx(self, x):
+ if self.Op is not None:
+ raise ValueError('lx cannot be computed when using Op')
X = x.reshape(self.n, self.k)
- # TODO: not clear how to handle Op
- #if self.Op is not None:
- # X = self.Op @ X
return np.linalg.norm(np.conj(X).T @ X, 'fro')
def ly(self, y):
- Y = np.conj(y.reshape(self.k, self.m)).T
- # TODO: not clear how to handle Op
- #if self.Op is not None:
- # Y = self.Op.H @ Y
- return np.linalg.norm(np.conj(Y).T @ Y, 'fro')
+ if self.Op is not None:
+ raise ValueError('ly cannot be computed when using Op')
+ Y = y.reshape(self.k, self.m)
+ return np.linalg.norm(Y @ np.conj(Y).T, 'fro')
def gradx(self, x):
r = self.d - self._matvecx(x)
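
The simplified `lx`/`ly` can be sanity-checked directly; for `H(X, Y) = ||d - XY||^2` the Lipschitz constants of its gradients are proportional to `||X^H X||_F` and `||Y Y^H||_F`, which is what the two methods return (a small sketch with illustrative sizes):

```python
import numpy as np
from pyproximal.utils.bilinear import LowRankFactorizedMatrix

np.random.seed(0)
n, m, k = 6, 5, 2
X = np.random.normal(0., 1., (n, k))
Y = np.random.normal(0., 1., (k, m))
d = np.random.normal(0., 1., n * m)

Hop = LowRankFactorizedMatrix(X, Y, d)
# ly drives the x-step size in PALM (ck = gammaf * H.ly(y)), lx the y-step
assert np.isclose(Hop.lx(X.ravel()),
                  np.linalg.norm(np.conj(X).T @ X, 'fro'))
assert np.isclose(Hop.ly(Y.ravel()),
                  np.linalg.norm(Y @ np.conj(Y).T, 'fro'))
```
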
diff --git a/pyproximal/utils/gradtest.py b/pyproximal/utils/gradtest.py
new file mode 100644
index 0000000..03b5910
--- /dev/null
+++ b/pyproximal/utils/gradtest.py
@@ -0,0 +1,231 @@
+import numpy as np
+
+from pylops.utils.backend import get_module, to_numpy
+
+
+def gradtest_proximal(Op, n, x=None, dtype="float64",
+ delta=1e-6, rtol=1e-6, atol=1e-21,
+ complexflag=False, raiseerror=True,
+ verb=False, backend="numpy"):
+ r"""Gradient test for Proximal operator.
+
+ Compute the gradient of ``Op`` using both the provided method and a
+ numerical approximation with a perturbation ``delta`` applied to a
+ single, randomly selected parameter of the input vector.
+
+ Parameters
+ ----------
+ Op : :obj:`pyproximal.ProxOperator`
+ Proximal operator to test.
+ n : :obj:`int`
+ Size of input vector
+ x : :obj:`numpy.ndarray`, optional
+ Input vector (if ``None``, randomly drawn from a
+ Normal distribution)
+ dtype : :obj:`str`, optional
+ Dtype of vector ``x`` to generate (only used when ``x=None``)
+ delta : :obj:`float`, optional
+ Perturbation
+ rtol : :obj:`float`, optional
+ Relative gradtest tolerance
+ atol : :obj:`float`, optional
+ Absolute gradtest tolerance
+ complexflag : :obj:`bool`, optional
+ Generate random vectors with real (``False``) or
+ complex (``True``) entries
+ raiseerror : :obj:`bool`, optional
+ Raise error or simply return ``False`` when gradtest fails
+ verb : :obj:`bool`, optional
+ Verbosity
+ backend : :obj:`str`, optional
+ Backend used for gradient test computations (``numpy`` or ``cupy``). This
+ parameter will be used to choose how to create the random vectors.
+
+ Returns
+ -------
+ passed : :obj:`bool`
+ Passed flag.
+
+ Raises
+ ------
+ AssertionError
+ If grad-test is not verified within chosen tolerances.
+
+ Notes
+ -----
+ A gradient test is a mathematical tool used in the development of numerical
+ nonlinear operators.
+
+ More specifically, a correct implementation of the gradient for
+ a nonlinear operator should verify the following *equality*
+ within a numerical tolerance:
+
+ .. math::
+ \frac{\partial Op(\mathbf{x})}{\partial x_i} \approx
+ \frac{Op(\mathbf{x}+\delta \mathbf{e}_i)-Op(\mathbf{x})}{\delta}
+
+ """
+ ncp = get_module(backend)
+
+ # generate random vector x if not provided (using the chosen backend)
+ if x is None:
+ x = ncp.random.normal(0., 1., n).astype(dtype)
+
+ if complexflag:
+ x = x + 1j * ncp.random.normal(0., 1., n).astype(dtype)
+
+ # compute function
+ f = Op(x)
+
+ # compute gradient
+ g = Op.grad(x)
+
+ # choose location of perturbation and whether to act on real or imag part
+ iqx = np.random.randint(0, n)
+ r_or_i = np.random.randint(0, 2) if complexflag else 0
+
+ if r_or_i == 0:
+ delta1 = delta
+ else:
+ delta1 = delta * 1j
+
+ # extract gradient value to test
+ x[iqx] = x[iqx] + delta1
+ grad = g[iqx]
+
+ # compute new function at perturbed location
+ fdelta = Op(x)
+
+ # evaluate if gradient test passed
+ grad_delta = (fdelta - f) / np.abs(delta)
+ grad_diff = grad_delta - (grad.real if r_or_i == 0 else grad.imag)
+ passed = np.isclose(grad_diff, 0, rtol, atol)
+
+ # verbosity or error raising
+ if (not passed and raiseerror) or verb:
+ passed_status = "passed" if passed else "failed"
+ msg = f"Grad test {passed_status}, Analytic={grad.real if r_or_i == 0 else grad.imag} - " \
+ f"Numeric={grad_delta}"
+ if not passed and raiseerror:
+ raise AssertionError(msg)
+ else:
+ print(msg)
+
+ return passed
+
+
+def gradtest_bilinear(Op, delta=1e-6, rtol=1e-6, atol=1e-21,
+ complexflag=False, raiseerror=True,
+ verb=False, backend="numpy"):
+ r"""Gradient test for Bilinear operator.
+
+ Compute the gradient of ``Op`` using both the provided method and a
+ numerical approximation with a perturbation ``delta`` applied to a
+ single, randomly selected parameter of either the ``x`` or ``y``
+ vectors.
+
+ Parameters
+ ----------
+ Op : :obj:`pyproximal.utils.bilinear.BilinearOperator`
+ Bilinear operator to test.
+ delta : :obj:`float`, optional
+ Perturbation
+ rtol : :obj:`float`, optional
+ Relative gradtest tolerance
+ atol : :obj:`float`, optional
+ Absolute gradtest tolerance
+ complexflag : :obj:`bool`, optional
+ Generate random vectors with real (``False``) or
+ complex (``True``) entries
+ raiseerror : :obj:`bool`, optional
+ Raise error or simply return ``False`` when gradtest fails
+ verb : :obj:`bool`, optional
+ Verbosity
+ backend : :obj:`str`, optional
+ Backend used for gradient test computations (``numpy`` or ``cupy``). This
+ parameter will be used to choose how to create the random vectors.
+
+ Returns
+ -------
+ passed : :obj:`bool`
+ Passed flag.
+
+ Raises
+ ------
+ AssertionError
+ If grad-test is not verified within chosen tolerances.
+
+ Notes
+ -----
+ A gradient test is a mathematical tool used in the development of numerical
+ bilinear operators.
+
+ More specifically, a correct implementation of the gradient for
+ a bilinear operator should verify the following *equalities*
+ within a numerical tolerance:
+
+ .. math::
+ \frac{\partial Op(\mathbf{x}, \mathbf{y})}{\partial x_i} \approx
+ \frac{Op(\mathbf{x}+\delta \mathbf{e}_i, \mathbf{y})-
+ Op(\mathbf{x}, \mathbf{y})}{\delta}
+
+ and
+
+ .. math::
+ \frac{\partial Op(\mathbf{x}, \mathbf{y})}{\partial y_i} \approx
+ \frac{Op(\mathbf{x}, \mathbf{y}+\delta \mathbf{e}_i)-
+ Op(\mathbf{x}, \mathbf{y})}{\delta}
+
+ """
+ ncp = get_module(backend)
+
+ nx = Op.sizex
+ ny = Op.sizey
+
+ # extract x and y from Op
+ x, y = Op.x.ravel(), Op.y.ravel()
+
+ # compute function at x and y
+ f = Op(x, y)
+
+ # compute gradients at x and y
+ gx = Op.gradx(x)
+ gy = Op.grady(y)
+
+ # choose location of perturbation, whether to act on x or y and on real or imag part
+ iqx, iqy = np.random.randint(0, nx), np.random.randint(0, ny)
+ x_or_y = np.random.randint(0, 2)
+
+ delta1 = delta
+ r_or_i = 0
+ if complexflag:
+ r_or_i = np.random.randint(0, 2)
+ if r_or_i == 1:
+ delta1 = delta * 1j
+
+ # extract gradient value to test
+ if x_or_y == 0:
+ x[iqx] = x[iqx] + delta1
+ grad = gx[iqx]
+ else:
+ y[iqy] = y[iqy] + delta1
+ grad = gy[iqy]
+
+ # compute new function at perturbed location
+ fdelta = Op(x, y)
+
+ # evaluate if gradient test passed
+ grad_delta = (fdelta - f) / np.abs(delta)
+ grad_diff = grad_delta - (grad.real if r_or_i == 0 else grad.imag)
+ passed = np.isclose(grad_diff, 0, rtol, atol)
+
+ # verbosity or error raising
+ if (not passed and raiseerror) or verb:
+ passed_status = "passed" if passed else "failed"
+ msg = f"Grad test {passed_status}, Analytic={grad.real if r_or_i == 0 else grad.imag} - " \
+ f"Numeric={grad_delta}"
+ if not passed and raiseerror:
+ raise AssertionError(msg)
+ else:
+ print(msg)
+
+ return passed
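
A minimal usage sketch for the new gradient tests (operator choice and tolerances are illustrative):

```python
import numpy as np
from pyproximal.proximal import L2
from pyproximal.utils.gradtest import gradtest_proximal

np.random.seed(0)
n = 16
l2 = L2(b=np.ones(n))  # f(x) = 1/2 ||x - b||^2, so grad f(x) = x - b

# compares the analytical gradient at a randomly chosen entry of x
# with its finite-difference approximation
assert gradtest_proximal(l2, n, delta=1e-6, atol=1e-3, verb=True)
```
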
diff --git a/pytests/test_grads.py b/pytests/test_grads.py
new file mode 100644
index 0000000..88b8fd4
--- /dev/null
+++ b/pytests/test_grads.py
@@ -0,0 +1,69 @@
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal
+
+from pylops.basicoperators import MatrixMult
+from pyproximal.proximal import L2
+from pyproximal.utils.bilinear import LowRankFactorizedMatrix
+from pyproximal.utils.gradtest import gradtest_proximal, gradtest_bilinear
+
+par1 = {'nx': 10, 'imag': 0, 'complexflag': False, 'dtype': 'float32'} # even float32
+par2 = {'nx': 11, 'imag': 0, 'complexflag': False, 'dtype': 'float64'} # odd float64
+par1j = {'nx': 10, 'imag': 1j, 'complexflag': True, 'dtype': 'complex64'} # even complex64
+par2j = {'nx': 11, 'imag': 1j, 'complexflag': True, 'dtype': 'complex128'} # odd complex128
+
+ngrads = 20 # number of gradient tests
+
+np.random.seed(10)
+
+
+@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
+def test_l2(par):
+ """L2 gradient
+ """
+
+ # x
+ l2 = L2()
+ for _ in range(ngrads):
+ gradtest_proximal(l2, par['nx'],
+ delta=1e-6, complexflag=par['complexflag'],
+ raiseerror=True, atol=1e-3,
+ verb=False)
+
+ # x - b
+ b = np.ones(par['nx'], dtype=par['dtype'])
+ l2 = L2(b=b)
+ for _ in range(ngrads):
+ gradtest_proximal(l2, par['nx'],
+ delta=1e-6, complexflag=par['complexflag'],
+ raiseerror=True, atol=1e-3,
+ verb=False)
+
+ # Opx - b
+ Op = MatrixMult(np.random.normal(0, 1, (2 * par['nx'], par['nx'])) +
+ par['imag'] * np.random.normal(0, 1, (2 * par['nx'], par['nx'])),
+ dtype=par['dtype'])
+ b = np.ones(2 * par['nx'], dtype=par['dtype'])
+ l2 = L2(b=b, Op=Op)
+ for _ in range(ngrads):
+ gradtest_proximal(l2, par['nx'],
+ delta=1e-6, complexflag=par['complexflag'],
+ raiseerror=True, atol=1e-3,
+ verb=False)
+
+@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
+def test_lowrank(par):
+ """LowRankFactorizedMatrix gradient
+ """
+ n, m, k = 2 * par['nx'], par['nx'], par['nx'] // 2
+ x = np.random.normal(0, 1, (n, k)) + par['imag'] * np.random.normal(0, 1, (n, k))
+ y = np.random.normal(0, 1, (k, m)) + par['imag'] * np.random.normal(0, 1, (k, m))
+ d = np.random.normal(0, 1, (n, m)) + par['imag'] * np.random.normal(0, 1, (n, m))
+
+ hop = LowRankFactorizedMatrix(x.copy(), y.copy(), d.ravel())
+
+ for _ in range(ngrads):
+ gradtest_bilinear(hop, delta=1e-6, complexflag=par['complexflag'],
+ raiseerror=True, atol=1e-3,
+ verb=False)
\ No newline at end of file
diff --git a/pytests/test_projection.py b/pytests/test_projection.py
index e2d4f81..b03316a 100644
--- a/pytests/test_projection.py
+++ b/pytests/test_projection.py
@@ -27,7 +27,7 @@ def test_Box(par):
assert moreau(box, x, tau)
-@pytest.mark.parametrize("par", [(par1), (par2)])
+@pytest.mark.parametrize("par", [(par1), ])
def test_EuclBall(par):
"""Euclidean Ball projection and proximal/dual proximal of related indicator
"""
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 925b886..ad0b3f5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,11 +1,12 @@
-numpy>=1.15.0
+numpy>=1.15.0, <2.0.0
scipy>=1.8.0
pylops>=2.0.0
numba
scikit-image
matplotlib
ipython
-bm3d
+bm4d<4.2.4 # temporary as glibc problem arises in readthedocs
+bm3d<4.0.2 # temporary as glibc problem arises in readthedocs
pytest
pytest-runner
setuptools_scm
diff --git a/tutorials/matrixfactorization.py b/tutorials/matrixfactorization.py
index fae93fd..629987e 100644
--- a/tutorials/matrixfactorization.py
+++ b/tutorials/matrixfactorization.py
@@ -26,6 +26,11 @@
plt.close('all')
np.random.seed(10)
+
+def callback(x, y, n, m, k, xtrue, snr_hist):
+ snr_hist.append(pylops.utils.metrics.snr(xtrue, x.reshape(n, k) @ y.reshape(k, m)))
+
+
###############################################################################
# Let's start by creating the matrix we want to factorize
n, m, k = 100, 90, 10
@@ -48,29 +53,103 @@
###############################################################################
# We are now ready to run the PALM algorithm
-Xest, Yest = \
+snr_palm = []
+Xpalm, Ypalm = \
pyproximal.optimization.palm.PALM(Hop, nn1, nn2, Xin.ravel(), Yin.ravel(),
- gammaf=2, gammag=2, niter=2000, show=True)
-Xest, Yest = Xest.reshape(Xin.shape), Yest.reshape(Yin.shape)
-Aest = Xest @ Yest
+ gammaf=2, gammag=2, niter=2000, show=True,
+ callback=lambda x, y: callback(x, y, n, m, k,
+ A, snr_palm))
+Xpalm, Ypalm = Xpalm.reshape(Xin.shape), Ypalm.reshape(Yin.shape)
+Apalm = Xpalm @ Ypalm
+
+fig, axs = plt.subplots(1, 5, figsize=(14, 3))
+fig.suptitle('PALM')
+axs[0].imshow(Xpalm, cmap='gray')
+axs[0].set_title('Xest')
+axs[0].axis('tight')
+axs[1].imshow(Ypalm, cmap='gray')
+axs[1].set_title('Yest')
+axs[1].axis('tight')
+axs[2].imshow(A, cmap='gray', vmin=10, vmax=37)
+axs[2].set_title('True')
+axs[2].axis('tight')
+axs[3].imshow(Apalm, cmap='gray', vmin=10, vmax=37)
+axs[3].set_title('Reconstructed')
+axs[3].axis('tight')
+axs[4].imshow(A - Apalm, cmap='gray', vmin=-.1, vmax=.1)
+axs[4].set_title('Reconstruction error')
+axs[4].axis('tight')
+fig.tight_layout()
###############################################################################
-# And finally we display the individual components and the reconstructed matrix
+# Similarly, we run the PALM algorithm with back-tracking
+snr_palmbt = []
+Xpalmbt, Ypalmbt = \
+ pyproximal.optimization.palm.PALM(Hop, nn1, nn2, Xin.ravel(), Yin.ravel(),
+ gammaf=None, gammag=None, niter=2000, show=True,
+ callback=lambda x, y: callback(x, y, n, m, k,
+ A, snr_palmbt))
+Xpalmbt, Ypalmbt = Xpalmbt.reshape(Xin.shape), Ypalmbt.reshape(Yin.shape)
+Apalmbt = Xpalmbt @ Ypalmbt
fig, axs = plt.subplots(1, 5, figsize=(14, 3))
-axs[0].imshow(Xest, cmap='gray')
+fig.suptitle('PALM with back-tracking')
+axs[0].imshow(Xpalmbt, cmap='gray')
axs[0].set_title('Xest')
axs[0].axis('tight')
-axs[1].imshow(Yest, cmap='gray')
+axs[1].imshow(Ypalmbt, cmap='gray')
axs[1].set_title('Yest')
axs[1].axis('tight')
-axs[2].imshow(A, cmap='gray')
+axs[2].imshow(A, cmap='gray', vmin=10, vmax=37)
axs[2].set_title('True')
axs[2].axis('tight')
-axs[3].imshow(Aest, cmap='gray')
+axs[3].imshow(Apalmbt, cmap='gray', vmin=10, vmax=37)
axs[3].set_title('Reconstructed')
axs[3].axis('tight')
-axs[4].imshow(A-Aest, cmap='gray')
+axs[4].imshow(A - Apalmbt, cmap='gray', vmin=-.1, vmax=.1)
axs[4].set_title('Reconstruction error')
axs[4].axis('tight')
fig.tight_layout()
+
+###############################################################################
+# And the iPALM algorithm
+snr_ipalm = []
+Xipalm, Yipalm = \
+ pyproximal.optimization.palm.iPALM(Hop, nn1, nn2, Xin.ravel(), Yin.ravel(),
+ gammaf=2, gammag=2, a=[0.8, 0.8],
+ niter=2000, show=True,
+ callback=lambda x, y: callback(x, y, n, m, k,
+ A, snr_ipalm))
+Xipalm, Yipalm = Xipalm.reshape(Xin.shape), Yipalm.reshape(Yin.shape)
+Aipalm = Xipalm @ Yipalm
+
+fig, axs = plt.subplots(1, 5, figsize=(14, 3))
+fig.suptitle('iPALM')
+axs[0].imshow(Xipalm, cmap='gray')
+axs[0].set_title('Xest')
+axs[0].axis('tight')
+axs[1].imshow(Yipalm, cmap='gray')
+axs[1].set_title('Yest')
+axs[1].axis('tight')
+axs[2].imshow(A, cmap='gray', vmin=10, vmax=37)
+axs[2].set_title('True')
+axs[2].axis('tight')
+axs[3].imshow(Aipalm, cmap='gray', vmin=10, vmax=37)
+axs[3].set_title('Reconstructed')
+axs[3].axis('tight')
+axs[4].imshow(A - Aipalm, cmap='gray', vmin=-.1, vmax=.1)
+axs[4].set_title('Reconstruction error')
+axs[4].axis('tight')
+fig.tight_layout()
+
+###############################################################################
+# And finally compare the convergence behaviour of the three methods
+fig, ax = plt.subplots(1, 1, figsize=(8, 5))
+ax.plot(snr_palm, 'k', lw=2, label='PALM')
+ax.plot(snr_palmbt, 'r', lw=2, label='PALM with back-tracking')
+ax.plot(snr_ipalm, 'g', lw=2, label='iPALM')
+ax.grid()
+ax.legend()
+ax.set_title('SNR')
+ax.set_xlabel('# Iteration')
+fig.tight_layout()