-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from sfarrens/new_release
Added pycodestyle tests and updated release
- Loading branch information
Showing
11 changed files
with
147 additions
and
162 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,9 +23,9 @@ ModOpt | |
|
||
:Author: Samuel Farrens `([email protected]) <[email protected]>`_ | ||
|
||
:Version: 1.2.0 | ||
:Version: 1.3.0 | ||
|
||
:Date: 21/11/2018 | ||
:Date: 27/03/2019 | ||
|
||
:Documentation: |link-to-docs| | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,9 +8,9 @@ ModOpt Documentation | |
|
||
:Author: Samuel Farrens <[email protected]> | ||
|
||
:Version: 1.2.0 | ||
:Version: 1.3.0 | ||
|
||
:Date: 21/11/2018 | ||
:Date: 27/03/2019 | ||
|
||
ModOpt is a series of Modular Optimisation tools for solving inverse problems. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,12 +6,12 @@ | |
:Author: Samuel Farrens <[email protected]> | ||
:Version: 1.2.0 | ||
:Version: 1.3.0 | ||
""" | ||
|
||
# Package Version | ||
version_info = (1, 2, 0) | ||
version_info = (1, 3, 0) | ||
__version__ = '.'.join(str(c) for c in version_info) | ||
|
||
__about__ = ('ModOpt \n\n ' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,8 @@ | |
This module contains class implementations of various optimisation algoritms. | ||
:Author: Samuel Farrens <[email protected]>, Zaccharie Ramzi <[email protected]> | ||
:Author: Samuel Farrens <[email protected]>, | ||
Zaccharie Ramzi <[email protected]> | ||
NOTES | ||
----- | ||
|
@@ -260,58 +261,43 @@ class FISTA(object): | |
None, # no restarting | ||
] | ||
|
||
def __init__( | ||
self, | ||
restart_strategy=None, | ||
min_beta=None, | ||
s_greedy=None, | ||
xi_restart=None, | ||
a_cd=None, | ||
p_lazy=1, | ||
q_lazy=1, | ||
r_lazy=4, | ||
): | ||
def __init__(self, restart_strategy=None, min_beta=None, s_greedy=None, | ||
xi_restart=None, a_cd=None, p_lazy=1, q_lazy=1, r_lazy=4): | ||
|
||
if isinstance(a_cd, type(None)): | ||
self.mode = 'regular' | ||
self.p_lazy = p_lazy | ||
self.q_lazy = q_lazy | ||
self.r_lazy = r_lazy | ||
|
||
elif a_cd > 2: | ||
self.mode = 'CD' | ||
self.a_cd = a_cd | ||
self._n = 0 | ||
|
||
else: | ||
raise ValueError( | ||
"a_cd must either be None (for regular mode) or a number > 2", | ||
) | ||
raise ValueError('a_cd must either be None (for regular mode) or ' | ||
'a number > 2') | ||
|
||
if restart_strategy in self.__class__.__restarting_strategies__: | ||
self._check_restart_params( | ||
restart_strategy, | ||
min_beta, | ||
s_greedy, | ||
xi_restart, | ||
) | ||
self._check_restart_params(restart_strategy, min_beta, s_greedy, | ||
xi_restart) | ||
self.restart_strategy = restart_strategy | ||
self.min_beta = min_beta | ||
self.s_greedy = s_greedy | ||
self.xi_restart = xi_restart | ||
|
||
else: | ||
raise ValueError( | ||
"Restarting strategy must be one of %s." % | ||
", ".join(self.__class__.__restarting_strategies__) | ||
) | ||
raise ValueError('Restarting strategy must be one of {}.'.format( | ||
', '.join( | ||
self.__class__.__restarting_strategies__))) | ||
self._t_now = 1.0 | ||
self._t_prev = 1.0 | ||
self._delta_0 = None | ||
self._safeguard = False | ||
|
||
def _check_restart_params( | ||
self, | ||
restart_strategy, | ||
min_beta, | ||
s_greedy, | ||
xi_restart, | ||
): | ||
def _check_restart_params(self, restart_strategy, min_beta, s_greedy, | ||
xi_restart): | ||
r""" Check restarting parameters | ||
This method checks that the restarting parameters are set and satisfy | ||
|
@@ -346,23 +332,24 @@ def _check_restart_params( | |
When a parameter that should be set isn't or doesn't verify the | ||
correct assumptions. | ||
""" | ||
|
||
if restart_strategy is None: | ||
return True | ||
|
||
if self.mode != 'regular': | ||
raise ValueError( | ||
"Restarting strategies can only be used with regular mode." | ||
) | ||
greedy_params_check = ( | ||
min_beta is None or s_greedy is None or s_greedy <= 1 | ||
) | ||
raise ValueError('Restarting strategies can only be used with ' | ||
'regular mode.') | ||
|
||
greedy_params_check = (min_beta is None or s_greedy is None or | ||
s_greedy <= 1) | ||
|
||
if restart_strategy == 'greedy' and greedy_params_check: | ||
raise ValueError( | ||
"You need a min_beta and an s_greedy > 1 for greedy restart." | ||
) | ||
raise ValueError('You need a min_beta and an s_greedy > 1 for ' | ||
'greedy restart.') | ||
|
||
if xi_restart is None or xi_restart >= 1: | ||
raise ValueError( | ||
"You need a xi_restart < 1 for restart." | ||
) | ||
raise ValueError('You need a xi_restart < 1 for restart.') | ||
|
||
return True | ||
|
||
def is_restart(self, z_old, x_new, x_old): | ||
|
@@ -393,18 +380,22 @@ def is_restart(self, z_old, x_new, x_old): | |
""" | ||
if self.restart_strategy is None: | ||
return False | ||
|
||
criterion = np.vdot(z_old - x_new, x_new - x_old) >= 0 | ||
|
||
if criterion: | ||
if 'adaptive' in self.restart_strategy: | ||
self.r_lazy *= self.xi_restart | ||
if self.restart_strategy in ['adaptive-ii', 'adaptive-2']: | ||
self._t_now = 1 | ||
|
||
if self.restart_strategy == 'greedy': | ||
cur_delta = np.linalg.norm(x_new - x_old) | ||
if self._delta_0 is None: | ||
self._delta_0 = self.s_greedy * cur_delta | ||
else: | ||
self._safeguard = cur_delta >= self._delta_0 | ||
|
||
return criterion | ||
|
||
def update_beta(self, beta): | ||
|
@@ -422,9 +413,11 @@ def update_beta(self, beta): | |
------- | ||
float: the new value for the beta parameter | ||
""" | ||
|
||
if self._safeguard: | ||
beta *= self.xi_restart | ||
beta = max(beta, self.min_beta) | ||
|
||
return beta | ||
|
||
def update_lambda(self, *args, **kwargs): | ||
|
@@ -441,12 +434,17 @@ def update_lambda(self, *args, **kwargs): | |
Implements steps 3 and 4 from algoritm 10.7 in [B2011]_ | ||
""" | ||
|
||
if self.restart_strategy == 'greedy': | ||
return 2 | ||
|
||
# Steps 3 and 4 from alg.10.7. | ||
self._t_prev = self._t_now | ||
|
||
if self.mode == 'regular': | ||
self._t_now = (self.p_lazy + np.sqrt(self.r_lazy * self._t_prev ** 2 + self.q_lazy)) * 0.5 | ||
self._t_now = (self.p_lazy + np.sqrt(self.r_lazy * | ||
self._t_prev ** 2 + self.q_lazy)) * 0.5 | ||
|
||
elif self.mode == 'CD': | ||
self._t_now = (self._n + self.a_cd - 1) / self.a_cd | ||
self._n += 1 | ||
|
@@ -538,7 +536,7 @@ def __init__(self, x, grad, prox, cost='auto', beta_param=1.0, | |
else: | ||
self._check_param_update(lambda_update) | ||
self._lambda_update = lambda_update | ||
self._is_restart = lambda *args, **kwargs:False | ||
self._is_restart = lambda *args, **kwargs: False | ||
|
||
# Automatically run the algorithm | ||
if auto_iterate: | ||
|
@@ -688,8 +686,8 @@ def __init__(self, x, grad, prox_list, cost='auto', gamma_param=1.0, | |
self._x_old = np.copy(x) | ||
|
||
# Set the algorithm operators | ||
(self._check_operator(operator) for operator in [grad, cost] | ||
+ prox_list) | ||
(self._check_operator(operator) for operator in [grad, cost] + | ||
prox_list) | ||
self._grad = grad | ||
self._prox_list = np.array(prox_list) | ||
self._linear = linear | ||
|
@@ -910,7 +908,7 @@ class Condat(SetUp): | |
""" | ||
|
||
def __init__(self, x, y, grad, prox, prox_dual, linear=None, cost='auto', | ||
reweight=None, rho=0.5, sigma=1.0, tau=1.0, rho_update=None, | ||
reweight=None, rho=0.5, sigma=1.0, tau=1.0, rho_update=None, | ||
sigma_update=None, tau_update=None, auto_iterate=True, | ||
max_iter=150, n_rewightings=1, metric_call_period=5, | ||
metrics={}): | ||
|
@@ -1070,6 +1068,7 @@ def retrieve_outputs(self): | |
metrics[obs.name] = obs.retrieve_metrics() | ||
self.metrics = metrics | ||
|
||
|
||
class POGM(SetUp): | ||
r"""Proximal Optimised Gradient Method | ||
|
@@ -1103,28 +1102,13 @@ class POGM(SetUp): | |
Option to automatically begin iterations upon initialisation (default | ||
is 'True') | ||
""" | ||
def __init__( | ||
self, | ||
u, | ||
x, | ||
y, | ||
z, | ||
grad, | ||
prox, | ||
cost='auto', | ||
linear=None, | ||
beta_param=1.0, | ||
sigma_bar=1.0, | ||
auto_iterate=True, | ||
metric_call_period=5, | ||
metrics={}, | ||
): | ||
def __init__(self, u, x, y, z, grad, prox, cost='auto', linear=None, | ||
beta_param=1.0, sigma_bar=1.0, auto_iterate=True, | ||
metric_call_period=5, metrics={}): | ||
|
||
# Set default algorithm properties | ||
super(POGM, self).__init__( | ||
metric_call_period=metric_call_period, | ||
metrics=metrics, | ||
linear=linear, | ||
) | ||
super(POGM, self).__init__(metric_call_period=metric_call_period, | ||
metrics=metrics, linear=linear) | ||
|
||
# set the initial variable values | ||
(self._check_input_data(data) for data in (u, x, y, z)) | ||
|
@@ -1145,7 +1129,7 @@ def __init__( | |
|
||
# Set the algorithm parameters | ||
(self._check_param(param) for param in (beta_param, sigma_bar)) | ||
if not (0 <= sigma_bar <=1): | ||
if not (0 <= sigma_bar <= 1): | ||
raise ValueError('The sigma bar parameter needs to be in [0, 1]') | ||
self._beta = beta_param | ||
self._sigma_bar = sigma_bar | ||
|
@@ -1169,7 +1153,7 @@ def _update(self): | |
""" | ||
# Step 4 from alg. 3 | ||
self._grad.get_grad(self._x_old) | ||
self._u_new = self._x_old - self._beta * self._grad.grad | ||
self._u_new = self._x_old - self._beta * self._grad.grad | ||
|
||
# Step 5 from alg. 3 | ||
self._t_new = 0.5 * (1 + np.sqrt(1 + 4 * self._t_old**2)) | ||
|
@@ -1218,7 +1202,6 @@ def _update(self): | |
self.converge = self.any_convergence_flag() or \ | ||
self._cost_func.get_cost(self._x_new) | ||
|
||
|
||
def iterate(self, max_iter=150): | ||
r"""Iterate | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.