Skip to content

Commit

Permalink
linted
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 11, 2024
1 parent 1397b6e commit 249b0f7
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import copy
from functools import wraps
from typing import Callable, Generator, Literal, Optional, Tuple, Union
import jax

import jax
import numpy as np
import scipy.linalg
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -1225,7 +1225,8 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict:
### Behavior
- For each additive basis component, the top-level key is the `label` of that basis component.
- If a component has a single input, the corresponding value in the dictionary will be the sub-array for that input.
- If a component has a single input, the corresponding value in the dictionary will be the sub-array
for that input.
- If a component has multiple inputs, the corresponding value will be another dictionary, with keys being the
input indices (0, 1, 2, ...) and values being sub-arrays of the features.
Expand All @@ -1248,7 +1249,8 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict:
A dictionary where:
- Top-level keys are the labels of the basis components.
- Values are sub-arrays of `x` that correspond to each component's features.
If a basis component has multiple inputs, the value will be another dictionary, with keys being input indices
If a basis component has multiple inputs, the value will be another dictionary, with keys being
input indices
(0, 1, 2, ...) and values being sub-arrays of `x`.
Raises
Expand Down Expand Up @@ -1286,7 +1288,8 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict:
feature_2: shape (20, 6)
>>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested:
>>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="multi_input", n_basis_input=2)
>>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10,
... label="multi_input", n_basis_input=2)
>>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2))
>>> split_features_multi = multi_input_basis.split_feature_axis(X_multi, axis=1)
>>> for feature, sub_dict in split_features_multi.items():
Expand All @@ -1308,9 +1311,11 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict:
"""
if x.shape[axis] != self.num_output_features:
raise ValueError("`x.shape[axis]` does not match the expected number of features."
f" `x.shape[axis] == {x.shape[axis]}`, while the expected number "
f"of features is {self.num_output_features}")
raise ValueError(
"`x.shape[axis]` does not match the expected number of features."
f" `x.shape[axis] == {x.shape[axis]}`, while the expected number "
f"of features is {self.num_output_features}"
)
# Get the slice dictionary based on predefined feature slicing
slice_dict = self._get_feature_slicing()[0]

Expand All @@ -1322,7 +1327,9 @@ def build_index_tuple(slice_obj, axis: int, ndim: int):
return tuple(index)

# Get the dict for slicing the correct axis
index_dict = jax.tree_util.tree_map(lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict)
index_dict = jax.tree_util.tree_map(
lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict
)

# Custom leaf function to identify index tuples as leaves
def is_leaf(val):
Expand Down

0 comments on commit 249b0f7

Please sign in to comment.