Skip to content

Commit

Permalink
isort fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pranmod01 committed Oct 18, 2024
1 parent 5d4617b commit a9799a3
Show file tree
Hide file tree
Showing 14 changed files with 21 additions and 33 deletions.
1 change: 0 additions & 1 deletion jax
Submodule jax deleted from 9cf952
2 changes: 1 addition & 1 deletion src/nemos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
styles,
tree_utils,
type_casting,
utils,
utils
)
2 changes: 1 addition & 1 deletion src/nemos/_documentation_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
plot_rates_and_smoothed_counts,
plot_weighted_sum_basis,
run_animation,
tuning_curve_plot,
tuning_curve_plot
)
5 changes: 2 additions & 3 deletions src/nemos/base_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from copy import deepcopy
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import jaxopt
from numpy.typing import ArrayLike, NDArray

import jax
import jax.numpy as jnp
import jaxopt
from numpy.typing import ArrayLike, NDArray

from . import solvers, utils, validation
from ._regularizer_builder import AVAILABLE_REGULARIZERS, create_regularizer
Expand Down
3 changes: 1 addition & 2 deletions src/nemos/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from functools import partial
from typing import Any, Literal, Optional

from numpy.typing import ArrayLike, NDArray

import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike, NDArray

from . import type_casting, utils

Expand Down
5 changes: 2 additions & 3 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
from functools import wraps
from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jaxopt
from numpy.typing import ArrayLike
from scipy.optimize import root

import jax
import jax.numpy as jnp

from . import observation_models as obs
from . import tree_utils, validation
from .base_regressor import BaseRegressor
Expand Down
3 changes: 1 addition & 2 deletions src/nemos/observation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import abc
from typing import Callable, Literal, Union

from numpy.typing import NDArray

import jax
import jax.numpy as jnp
from numpy.typing import NDArray

from . import utils
from .base_class import Base
Expand Down
5 changes: 2 additions & 3 deletions src/nemos/regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import abc
from typing import Callable, Tuple, Union

import jaxopt
from numpy.typing import NDArray

import jax
import jax.numpy as jnp
import jaxopt
from numpy.typing import NDArray

from . import tree_utils
from .base_class import Base
Expand Down
5 changes: 2 additions & 3 deletions src/nemos/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from typing import Callable, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np
import scipy.stats as sts
from numpy.typing import NDArray

import jax
import jax.numpy as jnp

from . import convolve, validation
from .pytrees import FeaturePytree

Expand Down
7 changes: 3 additions & 4 deletions src/nemos/solvers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from functools import partial
from typing import Callable, NamedTuple, Optional, Union

from jaxopt import OptStep
from jaxopt._src import loop
from jaxopt.prox import prox_none

import jax
import jax.flatten_util
import jax.numpy as jnp
from jax import grad, jit, lax, random
from jaxopt import OptStep
from jaxopt._src import loop
from jaxopt.prox import prox_none

from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub
from .typing import KeyArrayLike, Pytree
Expand Down
5 changes: 2 additions & 3 deletions src/nemos/type_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from functools import wraps
from typing import Any, Callable, List, Literal, Optional, Type, Union

import jax
import jax.numpy as jnp
import numpy as np
import pynapple as nap
from numpy.typing import NDArray

import jax
import jax.numpy as jnp

from . import tree_utils

_NAP_TIME_PRECISION = 10 ** (-nap.nap_config.time_index_precision)
Expand Down
3 changes: 1 addition & 2 deletions src/nemos/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from typing import Any, Callable, NamedTuple, Tuple, Union

import jaxopt

import jax.numpy as jnp
import jaxopt
from jax._src.typing import ArrayLike

from .pytrees import FeaturePytree
Expand Down
5 changes: 2 additions & 3 deletions src/nemos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import warnings
from typing import Any, Callable, List, Literal, Union

import numpy as np
from numpy.typing import NDArray

import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray

from .tree_utils import pytree_map_and_reduce
from .type_casting import is_numpy_array_like, support_pynapple
Expand Down
3 changes: 1 addition & 2 deletions src/nemos/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import warnings
from typing import Any, Optional, Union

from numpy.typing import DTypeLike, NDArray

import jax
import jax.numpy as jnp
from numpy.typing import DTypeLike, NDArray

from .pytrees import FeaturePytree
from .tree_utils import get_valid_multitree, pytree_map_and_reduce
Expand Down

0 comments on commit a9799a3

Please sign in to comment.