From 249b0f7a664d75f3dc59fa7d548379b4eec12793 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 11 Oct 2024 11:27:25 -0400 Subject: [PATCH] linted --- src/nemos/basis.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index edf6a9af..dcab040f 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -7,8 +7,8 @@ import copy from functools import wraps from typing import Callable, Generator, Literal, Optional, Tuple, Union -import jax +import jax import numpy as np import scipy.linalg from numpy.typing import ArrayLike, NDArray @@ -1225,7 +1225,8 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict: ### Behavior - For each additive basis component, the top-level key is the `label` of that basis component. - - If a component has a single input, the corresponding value in the dictionary will be the sub-array for that input. + - If a component has a single input, the corresponding value in the dictionary will be the sub-array + for that input. - If a component has multiple inputs, the corresponding value will be another dictionary, with keys being the input indices (0, 1, 2, ...) and values being sub-arrays of the features. @@ -1248,7 +1249,8 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict: A dictionary where: - Top-level keys are the labels of the basis components. - Values are sub-arrays of `x` that correspond to each component's features. - If a basis component has multiple inputs, the value will be another dictionary, with keys being input indices + If a basis component has multiple inputs, the value will be another dictionary, with keys being + input indices (0, 1, 2, ...) and values being sub-arrays of `x`. Raises @@ -1286,7 +1288,8 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict: feature_2: shape (20, 6) >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="multi_input", n_basis_input=2) + >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, + ... label="multi_input", n_basis_input=2) >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = multi_input_basis.split_feature_axis(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -1308,9 +1311,11 @@ def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict: """ if x.shape[axis] != self.num_output_features: - raise ValueError("`x.shape[axis]` does not match the expected number of features." - f" `x.shape[axis] == {x.shape[axis]}`, while the expected number " - f"of features is {self.num_output_features}") + raise ValueError( + "`x.shape[axis]` does not match the expected number of features." + f" `x.shape[axis] == {x.shape[axis]}`, while the expected number " + f"of features is {self.num_output_features}" + ) # Get the slice dictionary based on predefined feature slicing slice_dict = self._get_feature_slicing()[0] @@ -1322,7 +1327,9 @@ def build_index_tuple(slice_obj, axis: int, ndim: int): return tuple(index) # Get the dict for slicing the correct axis - index_dict = jax.tree_util.tree_map(lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict) + index_dict = jax.tree_util.tree_map( + lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict + ) # Custom leaf function to identify index tuples as leaves def is_leaf(val):