Skip to content

Commit

Permalink
linters
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 21, 2024
1 parent 8a08d3a commit 3dac13a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
6 changes: 4 additions & 2 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from . import tree_utils, validation
from .base_regressor import BaseRegressor
from .exceptions import NotFittedError
from .initialize_regressor import initialize_intercept_matching_mean_rate
from .pytrees import FeaturePytree
from .regularizer import GroupLasso, Lasso, Regularizer, Ridge
from .type_casting import jnp_asarray_if, support_pynapple
from .typing import DESIGN_INPUT_TYPE
from .initialize_regressor import initialize_intercept_matching_mean_rate

ModelParams = Tuple[jnp.ndarray, jnp.ndarray]

Expand Down Expand Up @@ -535,7 +535,9 @@ def _initialize_parameters(
else:
data = X

initial_intercept= initialize_intercept_matching_mean_rate(self.observation_model.inverse_link_function, y)
initial_intercept = initialize_intercept_matching_mean_rate(
self.observation_model.inverse_link_function, y
)

# Initialize parameters
init_params = (
Expand Down
34 changes: 20 additions & 14 deletions src/nemos/initialize_regressor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import jax.numpy as jnp
import jax
from scipy.optimize import root_scalar
from typing import Callable
from numpy.typing import ArrayLike

import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike
from scipy.optimize import root_scalar

# dictionary of known inverse link functions.
INVERSE_FUNCS = {
jnp.exp: jnp.log,
jax.nn.softplus: lambda x: jnp.log(jnp.exp(x) - 1.),
jax.nn.softplus: lambda x: jnp.log(jnp.exp(x) - 1.0),
}


def scalar_root_find_elementwise(func: Callable, args: ArrayLike, x0: ArrayLike) -> jnp.ndarray:
def scalar_root_find_elementwise(
func: Callable, args: ArrayLike, x0: ArrayLike
) -> jnp.ndarray:
"""
Find roots of a scalar function.
Expand Down Expand Up @@ -40,17 +42,19 @@ def scalar_root_find_elementwise(func: Callable, args: ArrayLike, x0: ArrayLike)
"""
opts = [root_scalar(func, arg, x0=x, method="secant") for arg, x in zip(args, x0)]

if not all(jnp.abs(func(opt.root, args[i])) < 10 ** -4 for i, opt in enumerate(opts)):
if not all(jnp.abs(func(opt.root, args[i])) < 10**-4 for i, opt in enumerate(opts)):
raise ValueError(
"Could not set the initial intercept as the inverse of the firing rate for "
"the provided link function. "
"Please, provide initial parameters instead!"
)
"Could not set the initial intercept as the inverse of the firing rate for "
"the provided link function. "
"Please, provide initial parameters instead!"
)

return jnp.array([opt.root for opt in opts])


def initialize_intercept_matching_mean_rate(inverse_link_function: Callable, y: jnp.ndarray) -> jnp.ndarray:
def initialize_intercept_matching_mean_rate(
inverse_link_function: Callable, y: jnp.ndarray
) -> jnp.ndarray:
"""
Compute the initial intercept term for a regression models.
Expand Down Expand Up @@ -81,8 +85,10 @@ def initialize_intercept_matching_mean_rate(inverse_link_function: Callable, y:
if analytical_inv:
out = analytical_inv(means)
if jnp.any(jnp.isnan(out)):
raise ValueError("Could not set the initial intercept as the inverse of the firing rate for "
"the provided link funciton. The mean firing rate assumes negative values.")
raise ValueError(
"Could not set the initial intercept as the inverse of the firing rate for "
"the provided link funciton. The mean firing rate assumes negative values."
)
return out

def func(x, mean_x):
Expand Down

0 comments on commit 3dac13a

Please sign in to comment.