Merge pull request #207 from bagibence/auto_stepsize_svrg
Automatic step sizes for SVRG
BalzaniEdoardo authored Oct 28, 2024
2 parents 0ac4b81 + 028b26a commit b85f408
Showing 12 changed files with 1,643 additions and 23 deletions.
Binary file added docs/assets/poisson_model_calc_stepsize.pdf
71 changes: 66 additions & 5 deletions src/nemos/base_regressor.py
@@ -6,6 +6,7 @@
import abc
import inspect
import warnings
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

@@ -228,7 +229,7 @@ def solver_name(self, solver_name: str):
if solver_name not in self._regularizer.allowed_solvers:
raise ValueError(
f"The solver: {solver_name} is not allowed for "
f"{self._regularizer.__class__.__name__} regularizaration. Allowed solvers are "
f"{self._regularizer.__class__.__name__} regularization. Allowed solvers are "
f"{self._regularizer.allowed_solvers}."
)
self._solver_name = solver_name
@@ -270,7 +271,9 @@ def _check_solver_kwargs(solver_class, solver_kwargs):
f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for {solver_class.__name__}!"
)

def instantiate_solver(self, *args) -> BaseRegressor:
def instantiate_solver(
self, *args, solver_kwargs: Optional[dict] = None
) -> BaseRegressor:
"""
Instantiate the solver with the provided loss function.
@@ -289,6 +292,9 @@ def instantiate_solver(self, *args) -> BaseRegressor:
*args:
Positional arguments for the jaxopt `solver.run` method, e.g. the regularizing
strength for proximal gradient methods.
solver_kwargs:
Optional dictionary with the solver kwargs.
If nothing is provided, it defaults to self.solver_kwargs.
Returns
-------
@@ -299,7 +305,7 @@ def instantiate_solver(self, *args) -> BaseRegressor:
if self.solver_name not in self.regularizer.allowed_solvers:
raise ValueError(
f"The solver: {self.solver_name} is not allowed for "
f"{self._regularizer.__class__.__name__} regularizaration. Allowed solvers are "
f"{self._regularizer.__class__.__name__} regularization. Allowed solvers are "
f"{self._regularizer.allowed_solvers}."
)

@@ -313,8 +319,9 @@ def instantiate_solver(self, *args) -> BaseRegressor:
else:
loss = self._predict_and_compute_loss

# copy dictionary of kwargs to avoid modifying user settings
solver_kwargs = deepcopy(self.solver_kwargs)
if solver_kwargs is None:
# copy dictionary of kwargs to avoid modifying user settings
solver_kwargs = deepcopy(self.solver_kwargs)

# check that the loss is Callable
utils.assert_is_callable(loss, "loss")
@@ -600,3 +607,57 @@ def _get_solver_class(solver_name: str):
)

return solver_class

def optimize_solver_params(self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray) -> dict:
"""
Compute and update solver parameters with optimal defaults if available.
This method checks the current solver configuration and, if an optimal
configuration is known for the given model parameters, computes the optimal
batch size, step size, and other hyperparameters to ensure faster convergence.
Parameters
----------
X :
Input data used to compute smoothness and strong convexity constants.
y :
Target values used in conjunction with X for the same purpose.
Returns
-------
:
A dictionary containing the solver parameters, updated with optimal defaults
where applicable.
"""
# Start with a copy of the existing solver parameters
new_solver_kwargs = self.solver_kwargs.copy()

# get the model specific configs
compute_defaults, compute_l_smooth, strong_convexity = (
self.get_optimal_solver_params_config()
)
if compute_defaults and compute_l_smooth:
# Check if the user has provided batch size or stepsize, or else use None
batch_size = new_solver_kwargs.get("batch_size", None)
stepsize = new_solver_kwargs.get("stepsize", None)

# Compute the optimal batch size and stepsize based on smoothness, strong convexity, etc.
new_params = compute_defaults(
compute_l_smooth,
X,
y,
batch_size=batch_size,
stepsize=stepsize,
strong_convexity=strong_convexity,
)

# Update the solver parameters with the computed optimal values
new_solver_kwargs.update(new_params)

return new_solver_kwargs

@abstractmethod
def get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
pass
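
To make the flow above concrete, here is a small, self-contained sketch of the "fill in defaults only where the user did not choose" pattern that `optimize_solver_params` implements. The helper and its constants below are illustrative stand-ins, not the nemos implementation.

```python
from typing import Optional


def toy_compute_defaults(
    l_smooth: float,
    batch_size: Optional[int] = None,
    stepsize: Optional[float] = None,
) -> dict:
    """Fill in whichever of batch_size/stepsize the user left unset (toy rule)."""
    if batch_size is None:
        batch_size = 1                  # fallback mini-batch size
    if stepsize is None:
        stepsize = 0.25 / l_smooth      # a 1/(4L)-style step for an L-smooth loss
    return {"batch_size": batch_size, "stepsize": stepsize}


# the user only set a tolerance; batch_size and stepsize get theory-based defaults
user_solver_kwargs = {"tol": 1e-8}
new_solver_kwargs = dict(user_solver_kwargs)
new_solver_kwargs.update(toy_compute_defaults(l_smooth=2.0))
print(new_solver_kwargs)  # {'tol': 1e-08, 'batch_size': 1, 'stepsize': 0.125}
```

As in the method above, values the user passed explicitly in `solver_kwargs` take precedence and are never overwritten.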
77 changes: 68 additions & 9 deletions src/nemos/glm.py
@@ -19,6 +19,7 @@
from .initialize_regressor import initialize_intercept_matching_mean_rate
from .pytrees import FeaturePytree
from .regularizer import GroupLasso, Lasso, Regularizer, Ridge
from .solvers._compute_defaults import glm_compute_optimal_stepsize_configs
from .type_casting import jnp_asarray_if, support_pynapple
from .typing import DESIGN_INPUT_TYPE

@@ -55,10 +56,36 @@ class GLM(BaseRegressor):
| Regularizer | Default Solver | Available Solvers |
| ------------- | ---------------- | ----------------------------------------------------------- |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso | ProximalGradient | ProximalGradient |
| GroupLasso | ProximalGradient | ProximalGradient |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Lasso | ProximalGradient | ProximalGradient, ProxSVRG |
| GroupLasso | ProximalGradient | ProximalGradient, ProxSVRG |
**Fitting Large Models**
For very large models, you may consider using the Stochastic Variance Reduced Gradient
([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solver,
which take advantage of batched computation. You can change the solver by passing
`"SVRG"` as `solver_name` at model initialization.
The performance of the SVRG solver depends critically on the choice of `batch_size` and `stepsize`
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.
To assist with this, for certain GLM configurations, we provide `batch_size` and `stepsize` default
values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed default hyperparameters:
| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
| --------------------------------- | :------: | :---------: |
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
| Poisson + soft-plus + Ridge | ✅ | ✅ |
| Poisson + soft-plus + Lasso | ✅ | ❌ |
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |
Parameters
----------
@@ -968,8 +995,10 @@ def initialize_state(
)
self.regularizer.mask = jnp.ones((1, data.shape[1]))

opt_solver_kwargs = self.optimize_solver_params(data, y)

# set up the solver init/run/update attrs
self.instantiate_solver()
self.instantiate_solver(solver_kwargs=opt_solver_kwargs)

opt_state = self.solver_init_state(init_params, data, y)
return opt_state
@@ -1066,6 +1095,10 @@ def update(

return opt_step

def get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
return glm_compute_optimal_stepsize_configs(self)


class PopulationGLM(GLM):
"""
@@ -1081,10 +1114,36 @@ class PopulationGLM(GLM):
| Regularizer | Default Solver | Available Solvers |
| ------------- | ---------------- | ----------------------------------------------------------- |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso | ProximalGradient | ProximalGradient |
| GroupLasso | ProximalGradient | ProximalGradient |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Lasso | ProximalGradient | ProximalGradient, ProxSVRG |
| GroupLasso | ProximalGradient | ProximalGradient, ProxSVRG |
**Fitting Large Models**
For very large models, you may consider using the Stochastic Variance Reduced Gradient
([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solver,
which take advantage of batched computation. You can change the solver by passing
`"SVRG"` or `"ProxSVRG"` as `solver_name` at model initialization.
The performance of the SVRG solver depends critically on the choice of `batch_size` and `stepsize`
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.
To assist with this, for certain GLM configurations, we provide `batch_size` and `stepsize` default
values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed hyperparameters:
| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
| --------------------------------- | :------: | :---------: |
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
| Poisson + soft-plus + Ridge | ✅ | ✅ |
| Poisson + soft-plus + Lasso | ✅ | ❌ |
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |
Parameters
----------
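
As a usage-level illustration of the new docstring text above, the sketch below selects the SVRG solver for a configuration with guaranteed defaults (Poisson observations, softplus link, Ridge penalty) and leaves `batch_size`/`stepsize` unset so they are computed from the data at fit time. The constructor arguments shown are assumptions about the public nemos API, not something this diff adds; the same pattern applies to `PopulationGLM`, whose table lists identical supported configurations.

```python
import jax
import numpy as np
import nemos as nmo

# toy design matrix and spike counts (illustrative data, not from the PR)
rng = np.random.default_rng(0)
X = rng.normal(size=(5_000, 10))
y = rng.poisson(lam=1.0, size=5_000)

model = nmo.glm.GLM(
    observation_model=nmo.observation_models.PoissonObservations(
        inverse_link_function=jax.nn.softplus
    ),
    regularizer="Ridge",
    regularizer_strength=0.1,
    solver_name="SVRG",   # or "ProxSVRG" for Lasso / GroupLasso
    # no batch_size / stepsize in solver_kwargs: defaults are computed from (X, y)
)
model.fit(X, y)
```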
5 changes: 5 additions & 0 deletions src/nemos/solvers/__init__.py
@@ -0,0 +1,5 @@
from ._svrg import SVRG, ProxSVRG
from ._svrg_defaults import (
glm_softplus_poisson_l_max_and_l,
svrg_optimal_batch_and_stepsize,
)
70 changes: 70 additions & 0 deletions src/nemos/solvers/_compute_defaults.py
@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import jax

from ..observation_models import PoissonObservations
from ..regularizer import Ridge
from ._svrg_defaults import (
glm_softplus_poisson_l_max_and_l,
svrg_optimal_batch_and_stepsize,
)

if TYPE_CHECKING:
from ..glm import GLM, PopulationGLM


def glm_compute_optimal_stepsize_configs(
model: Union[GLM, PopulationGLM]
) -> Tuple[Optional[Callable], Optional[Callable], Optional[float]]:
"""
Compute configuration functions for optimal step size selection based on the model.
This function returns a tuple of three elements that are used for configuring the
optimal step size and batch size for variance reduced gradient (SVRG and
ProxSVRG) algorithms. If the model is configured with specific solver names,
the appropriate computation functions are returned. Additionally, it determines the
smoothness and strong convexity constants based on the model's observation and regularizer.
Parameters
----------
model :
The generalized linear model object for which the optimal step size and batch
configuration need to be computed.
Returns
-------
compute_optimal_params :
A function to compute the optimal batch size and step size if the model
is configured with the SVRG or ProxSVRG solver, None otherwise.
compute_smoothness :
A function to compute the smoothness constant of the loss function if the
observation model uses a softplus inverse link function and is a Poisson
observation model, None otherwise.
strong_convexity :
The strong convexity constant of the loss function if the model has a
Ridge regularizer. If the model does not have a Ridge regularizer, this
value will be None.
"""
# initialize funcs and strong convexity constant
compute_optimal_params = None
compute_smoothness = None
strong_convexity = (
None if not isinstance(model.regularizer, Ridge) else model.regularizer_strength
)

# look-up table for selecting the optimal step and batch
if model.solver_name in ("SVRG", "ProxSVRG"):
compute_optimal_params = svrg_optimal_batch_and_stepsize

# get the smoothness parameter compute function
if model.observation_model.inverse_link_function is jax.nn.softplus and isinstance(
model.observation_model, PoissonObservations
):
compute_smoothness = glm_softplus_poisson_l_max_and_l

return compute_optimal_params, compute_smoothness, strong_convexity
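
The two helpers that `glm_compute_optimal_stepsize_configs` wires together are called with the same shape as the `compute_defaults(compute_l_smooth, X, y, ...)` call in `optimize_solver_params` above. A hedged, self-contained sketch of invoking them directly (any argument beyond those visible in this diff is an assumption):

```python
import jax.numpy as jnp
import numpy as np

from nemos.solvers import (
    glm_softplus_poisson_l_max_and_l,
    svrg_optimal_batch_and_stepsize,
)

rng = np.random.default_rng(0)
X = jnp.asarray(rng.normal(size=(5_000, 10)))
y = jnp.asarray(rng.poisson(lam=1.0, size=5_000), dtype=float)

# same call shape as compute_defaults(compute_l_smooth, X, y, ...) above
params = svrg_optimal_batch_and_stepsize(
    glm_softplus_poisson_l_max_and_l,  # smoothness constants for a softplus-Poisson GLM
    X,
    y,
    batch_size=None,        # let the helper choose
    stepsize=None,          # let the helper choose
    strong_convexity=0.1,   # e.g. the Ridge regularizer strength
)
print(params)  # expected to contain "batch_size" and/or "stepsize" keys
```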
16 changes: 8 additions & 8 deletions src/nemos/solvers.py → src/nemos/solvers/_svrg.py
@@ -9,8 +9,8 @@
from jaxopt._src import loop
from jaxopt.prox import prox_none

from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub
from .typing import KeyArrayLike, Pytree
from ..tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub
from ..typing import KeyArrayLike, Pytree


class SVRGState(NamedTuple):
@@ -207,7 +207,7 @@ def _inner_loop_param_update_step(
# gradient of f_{i_k} at x_{k} in the pseudocode of Gower et al. 2020
minibatch_grad_at_current_params = self.loss_gradient(params, *args)
# gradient on batch_{i_k} evaluated at the anchor point
# gradient of f_{i_k} at x_{x} in the pseudocode of Gower et al. 2020
# gradient of f_{i_k} at x_{k} in the pseudocode of Gower et al. 2020
minibatch_grad_at_reference_point = self.loss_gradient(reference_point, *args)

# SVRG gradient estimate
@@ -575,7 +575,7 @@ def inner_loop_body(_, carry):
@staticmethod
def _error(x, x_prev, stepsize):
"""
Calculate the magnitude of the update relative to the parameters.
Calculate the magnitude of the update relative to the stepsize.
Used for terminating the algorithm if a certain tolerance is reached.
Params
@@ -589,15 +589,15 @@ def _error(x, x_prev, stepsize):
-------
Scaled update magnitude.
"""
# stepsize is an argument to be consistent with jaxopt
return tree_l2_norm(tree_sub(x, x_prev)) / tree_l2_norm(x_prev)
return tree_l2_norm(tree_sub(x, x_prev)) / stepsize


class SVRG(ProxSVRG):
"""
SVRG solver
SVRG solver.
Equivalent to ProxSVRG with prox as the identity function and hyperparams_prox=None.
This solver implements "Algorithm 3" of [1]. Equivalent to ProxSVRG with prox as the identity
function and hyperparams_prox=None.
Attributes
----------
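
For reference, the variance-reduced gradient estimate that the `_inner_loop_param_update_step` hunk above comments on (from the pseudocode of Gower et al. 2020) can be written, in a hedged toy form on flat NumPy arrays rather than nemos pytrees, as:

```python
import numpy as np

# toy least-squares problem: f_i(w) = 0.5 * (x_i @ w - y_i) ** 2
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=200)


def grad_minibatch(w, idx):
    """Gradient of the average loss over the mini-batch `idx`."""
    residual = X[idx] @ w - y[idx]
    return X[idx].T @ residual / len(idx)


def svrg_gradient_estimate(w, w_ref, full_grad_at_ref, idx):
    """g_k = grad f_{i_k}(x_k) - grad f_{i_k}(x_ref) + grad f(x_ref)."""
    return grad_minibatch(w, idx) - grad_minibatch(w_ref, idx) + full_grad_at_ref


w_ref = np.zeros(3)                                   # anchor / reference point
full_grad = grad_minibatch(w_ref, np.arange(len(y)))  # full gradient at the anchor
w = w_ref.copy()
for _ in range(100):                                  # one SVRG inner loop
    idx = rng.integers(0, len(y), size=32)            # sample a mini-batch
    w -= 0.1 * svrg_gradient_estimate(w, w_ref, full_grad, idx)
```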
(Remaining changed files not shown.)
