diff --git a/jax b/jax deleted file mode 160000 index 9cf952a5..00000000 --- a/jax +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9cf952a535518da59cdcecc9145dba287beddca2 diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py index 97c5b3db..aedd05c0 100644 --- a/src/nemos/__init__.py +++ b/src/nemos/__init__.py @@ -14,5 +14,5 @@ styles, tree_utils, type_casting, - utils, + utils ) diff --git a/src/nemos/_documentation_utils/__init__.py b/src/nemos/_documentation_utils/__init__.py index 3cd63e0e..1c64a43a 100644 --- a/src/nemos/_documentation_utils/__init__.py +++ b/src/nemos/_documentation_utils/__init__.py @@ -19,5 +19,5 @@ plot_rates_and_smoothed_counts, plot_weighted_sum_basis, run_animation, - tuning_curve_plot, + tuning_curve_plot ) diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py index ba2782ed..e4a425ce 100644 --- a/src/nemos/base_regressor.py +++ b/src/nemos/base_regressor.py @@ -9,11 +9,10 @@ from copy import deepcopy from typing import Any, Dict, NamedTuple, Optional, Tuple, Union -import jaxopt -from numpy.typing import ArrayLike, NDArray - import jax import jax.numpy as jnp +import jaxopt +from numpy.typing import ArrayLike, NDArray from . import solvers, utils, validation from ._regularizer_builder import AVAILABLE_REGULARIZERS, create_regularizer diff --git a/src/nemos/convolve.py b/src/nemos/convolve.py index e9a8d1fc..6a6294c3 100644 --- a/src/nemos/convolve.py +++ b/src/nemos/convolve.py @@ -8,10 +8,9 @@ from functools import partial from typing import Any, Literal, Optional -from numpy.typing import ArrayLike, NDArray - import jax import jax.numpy as jnp +from numpy.typing import ArrayLike, NDArray from . import type_casting, utils diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 70b4ee9e..44cbb4d4 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -7,13 +7,12 @@ from functools import wraps from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Union +import jax +import jax.numpy as jnp import jaxopt from numpy.typing import ArrayLike from scipy.optimize import root -import jax -import jax.numpy as jnp - from . import observation_models as obs from . import tree_utils, validation from .base_regressor import BaseRegressor diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py index de0b8981..9d683ae1 100644 --- a/src/nemos/observation_models.py +++ b/src/nemos/observation_models.py @@ -3,10 +3,9 @@ import abc from typing import Callable, Literal, Union -from numpy.typing import NDArray - import jax import jax.numpy as jnp +from numpy.typing import NDArray from . import utils from .base_class import Base diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index c2cc75d6..6d6cf0bd 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -9,11 +9,10 @@ import abc from typing import Callable, Tuple, Union -import jaxopt -from numpy.typing import NDArray - import jax import jax.numpy as jnp +import jaxopt +from numpy.typing import NDArray from . import tree_utils from .base_class import Base diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py index 92c64c12..e698e702 100644 --- a/src/nemos/simulation.py +++ b/src/nemos/simulation.py @@ -2,13 +2,12 @@ from typing import Callable, Tuple, Union +import jax +import jax.numpy as jnp import numpy as np import scipy.stats as sts from numpy.typing import NDArray -import jax -import jax.numpy as jnp - from . import convolve, validation from .pytrees import FeaturePytree diff --git a/src/nemos/solvers.py b/src/nemos/solvers.py index 779d24d0..4c060609 100644 --- a/src/nemos/solvers.py +++ b/src/nemos/solvers.py @@ -1,14 +1,13 @@ from functools import partial from typing import Callable, NamedTuple, Optional, Union -from jaxopt import OptStep -from jaxopt._src import loop -from jaxopt.prox import prox_none - import jax import jax.flatten_util import jax.numpy as jnp from jax import grad, jit, lax, random +from jaxopt import OptStep +from jaxopt._src import loop +from jaxopt.prox import prox_none from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub from .typing import KeyArrayLike, Pytree diff --git a/src/nemos/type_casting.py b/src/nemos/type_casting.py index 02a73dd8..8a5522b8 100644 --- a/src/nemos/type_casting.py +++ b/src/nemos/type_casting.py @@ -12,13 +12,12 @@ from functools import wraps from typing import Any, Callable, List, Literal, Optional, Type, Union +import jax +import jax.numpy as jnp import numpy as np import pynapple as nap from numpy.typing import NDArray -import jax -import jax.numpy as jnp - from . import tree_utils _NAP_TIME_PRECISION = 10 ** (-nap.nap_config.time_index_precision) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index fa86ca82..dd9bc5a6 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -2,9 +2,8 @@ from typing import Any, Callable, NamedTuple, Tuple, Union -import jaxopt - import jax.numpy as jnp +import jaxopt from jax._src.typing import ArrayLike from .pytrees import FeaturePytree diff --git a/src/nemos/utils.py b/src/nemos/utils.py index f457e4f1..87ad0472 100644 --- a/src/nemos/utils.py +++ b/src/nemos/utils.py @@ -3,11 +3,10 @@ import warnings from typing import Any, Callable, List, Literal, Union -import numpy as np -from numpy.typing import NDArray - import jax import jax.numpy as jnp +import numpy as np +from numpy.typing import NDArray from .tree_utils import pytree_map_and_reduce from .type_casting import is_numpy_array_like, support_pynapple diff --git a/src/nemos/validation.py b/src/nemos/validation.py index c3ff63bf..9dd0853f 100644 --- a/src/nemos/validation.py +++ b/src/nemos/validation.py @@ -3,10 +3,9 @@ import warnings from typing import Any, Optional, Union -from numpy.typing import DTypeLike, NDArray - import jax import jax.numpy as jnp +from numpy.typing import DTypeLike, NDArray from .pytrees import FeaturePytree from .tree_utils import get_valid_multitree, pytree_map_and_reduce