# Excerpt reconstructed from a patch to src/nemos/basis.py. The definitions
# below take `self` and are methods of a basis class whose header is outside
# this excerpt, which is why they appear here at module level.
import jax


@property
def num_output_features(self) -> int:
    """Read-only property returning the number of features returned by the basis."""
    return self._num_output_features


@property
def label(self) -> str:
    """Return the label stored on this basis."""
    return self._label


def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict:
    r"""
    Decompose an array along its feature axis into per-component sub-arrays.

    This method takes a concatenated feature array, such as a design matrix
    generated by ``Basis.compute_features`` or the coefficients of a fitted
    GLM, and splits it into separate sub-arrays so that individual basis
    components can be analyzed on their own.

    The returned dictionary mirrors the structure of the slice dictionary
    produced by ``self._get_feature_slicing()``:

    - For each additive basis component, the top-level key is the ``label``
      of that component.
    - If a component has a single input, the value is the sub-array for
      that input.
    - If a component has multiple inputs, the value is another dictionary
      whose keys are the input indices (0, 1, 2, ...) and whose values are
      the corresponding sub-arrays.

    Parameters
    ----------
    x : NDArray
        The array to split. It can be a multidimensional array holding
        concatenated features, model coefficients, or any other data whose
        size along ``axis`` equals ``self.num_output_features``.
    axis : int, optional
        The axis of ``x`` to split. Defaults to 1, which is typically the
        feature axis of a design matrix. Negative values index from the
        last axis, as in regular array indexing.

    Returns
    -------
    dict
        A (possibly nested) dictionary with the same keys as the slice
        dictionary, where every slice has been replaced by the matching
        sub-array of ``x``. All axes other than ``axis`` are left whole.

    Raises
    ------
    ValueError
        If ``x.shape[axis]`` does not match ``self.num_output_features``.

    Notes
    -----
    - This method relies on ``self._get_feature_slicing()`` to obtain the
      slice dictionary for each component.
    - The slices are applied with a single ``jax.tree_util.tree_map`` over
      that dictionary; each slice leaf is expanded into a full index tuple
      and used to index ``x`` directly.

    Examples
    --------
    >>> import numpy as np
    >>> from nemos.basis import BSplineBasis
    >>> from nemos.glm import GLM

    >>> # Define an additive basis
    >>> basis = (
    ...     BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1")
    ...     + BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2")
    ... )

    >>> # Generate a sample input array and compute features
    >>> x1, x2 = np.random.randn(20), np.random.randn(20)
    >>> X = basis.compute_features(x1, x2)

    >>> # Split the feature matrix along axis 1
    >>> split_features = basis.split_feature_axis(X, axis=1)
    >>> for feature, arr in split_features.items():
    ...     print(f"{feature}: shape {arr.shape}")
    feature_1: shape (20, 5)
    feature_2: shape (20, 6)

    >>> # If a basis component accepts multiple inputs, the result is nested:
    >>> multi_input_basis = BSplineBasis(
    ...     n_basis_funcs=6, mode="conv", window_size=10, label="multi_input", n_basis_input=2
    ... )
    >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2))
    >>> split_features_multi = multi_input_basis.split_feature_axis(X_multi, axis=1)
    >>> for feature, sub_dict in split_features_multi.items():
    ...     print(f"{feature}:")
    ...     for input_num, arr in sub_dict.items():
    ...         print(f"input number {int(input_num)}: shape {arr.shape}")
    multi_input:
    input number 0: shape (20, 6)
    input number 1: shape (20, 6)

    >>> # The method can also decompose GLM coefficients by feature
    >>> counts = np.random.poisson(size=20)
    >>> model = GLM().fit(X, counts)
    >>> split_coef = basis.split_feature_axis(model.coef_, axis=0)
    >>> for feature, coef in split_coef.items():
    ...     print(f"{feature}: shape {coef.shape}")
    feature_1: shape (5,)
    feature_2: shape (6,)

    """
    if x.shape[axis] != self.num_output_features:
        raise ValueError(
            "`x.shape[axis]` does not match the expected number of features."
            f" `x.shape[axis] == {x.shape[axis]}`, while the expected number "
            f"of features is {self.num_output_features}"
        )

    # First element of the returned pair is the (possibly nested) dict of slices.
    slice_dict = self._get_feature_slicing()[0]

    def take_along_axis(feature_slice):
        """Return the part of `x` selected by `feature_slice` on `axis`."""
        index = [slice(None)] * x.ndim  # keep every other axis whole
        index[axis] = feature_slice
        return x[tuple(index)]

    # Slice objects are leaves of the tree, so one pass maps every slice to
    # its sub-array. Indexing inside the mapped function avoids building an
    # intermediate tree of index tuples, which would itself be flattened and
    # would need a custom `is_leaf` predicate to be protected.
    return jax.tree_util.tree_map(take_along_axis, slice_dict)