Skip to content

Commit

Permalink
changed typing
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 16, 2024
1 parent b3f801f commit 26499a7
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/nemos/solvers/_svrg_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax
import jax.numpy as jnp
from numpy.typing import NDArray


def _convert_to_float(func):
Expand Down Expand Up @@ -52,22 +53,22 @@ def svrg_optimal_batch_and_stepsize(
Parameters
----------
compute_smoothness_constants : Callable
compute_smoothness_constants :
Function that computes the smoothness constants `l_smooth` and `l_smooth_max` for the problem.
This is problem (loss function) specific.
data : Any
data :
Input data, typically (X, y) for a GLM.
batch_size : Optional[int], default None
batch_size :
The batch size set by the user. If None, it will be calculated.
stepsize : Optional[float], default None
stepsize :
The step size set by the user. If None, it will be calculated.
strong_convexity : Optional[float], default None
strong_convexity :
The strong convexity constant. For L2-regularized losses, this should be the regularization strength.
n_power_iters : Optional[int], default None
n_power_iters :
Maximum number of iterations for the power method when finding the largest eigenvalue.
default_batch_size : int, default 1
default_batch_size :
Default batch size to use if the optimal calculation fails.
default_stepsize : float, default 1e-3
default_stepsize :
Default step size to use if the optimal calculation fails.
Returns
Expand Down Expand Up @@ -145,8 +146,8 @@ def svrg_optimal_batch_and_stepsize(


def glm_softplus_poisson_l_max_and_l(
*data: jnp.ndarray, n_power_iters: Optional[int] = 20
) -> Tuple[float, float]:
*data: NDArray, n_power_iters: Optional[int] = 20
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Calculate smoothness constants for a Poisson GLM with a softplus inverse link function.
Expand Down Expand Up @@ -180,13 +181,14 @@ def glm_softplus_poisson_l_max_and_l(


def _glm_softplus_poisson_l_smooth_multiply(
X: jnp.ndarray, y: jnp.ndarray, v: jnp.ndarray, batch_size: int
X: NDArray, y: NDArray, v: NDArray, batch_size: int
):
"""
Multiply vector `v` with the matrix X.T @ D @ X without forming it explicitly.
This method estimates the multiplication by calculating the Hessian of the loss.
It is efficient for situations where X can fit in memory.
If batch_size is provided, the computation will be done by slicing the array.
Parameters
----------
Expand All @@ -212,7 +214,7 @@ def _glm_softplus_poisson_l_smooth_multiply(


def _glm_softplus_poisson_l_smooth_with_power_iteration(
X: jnp.ndarray, y: jnp.ndarray, n_power_iters: int = 20, batch_size: Optional[int] = None
X: NDArray, y: NDArray, n_power_iters: int = 20, batch_size: Optional[int] = None
):
"""
Compute the largest eigenvalue of X.T @ D @ X using the power method.
Expand Down Expand Up @@ -262,7 +264,7 @@ def _glm_softplus_poisson_l_smooth_with_power_iteration(


def _glm_softplus_poisson_l_smooth(
X: jnp.ndarray, y: jnp.ndarray, n_power_iters: Optional[int] = None
X: NDArray, y: NDArray, n_power_iters: Optional[int] = None
) -> jnp.ndarray:
"""
Calculate the smoothness constant `L` for a Poisson GLM with softplus inverse link.
Expand All @@ -272,11 +274,11 @@ def _glm_softplus_poisson_l_smooth(
Parameters
----------
X : jnp.ndarray
X :
Input data matrix (N x d).
y : jnp.ndarray
y :
Output data vector (N,).
n_power_iters : Optional[int], default None
n_power_iters :
Number of power iterations to use when finding the largest eigenvalue. If None,
the eigenvalue is calculated directly.
Expand All @@ -294,7 +296,7 @@ def _glm_softplus_poisson_l_smooth(
return _glm_softplus_poisson_l_smooth_with_power_iteration(X, y, n_power_iters)


def _glm_softplus_poisson_l_smooth_max(X: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
def _glm_softplus_poisson_l_smooth_max(X: NDArray, y: NDArray) -> NDArray:
"""
Calculate the maximum smoothness constant `L_max` for individual observations.
Expand Down

0 comments on commit 26499a7

Please sign in to comment.