Skip to content

Commit

Permalink
improved description of the split method
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 11, 2024
1 parent 3dffece commit 8f1b832
Showing 1 changed file with 129 additions and 0 deletions.
129 changes: 129 additions & 0 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
from functools import wraps
from typing import Callable, Generator, Literal, Optional, Tuple, Union
import jax

import numpy as np
import scipy.linalg
Expand Down Expand Up @@ -515,6 +516,11 @@ def __init__(
self.kernel_ = None
self._identifiability_constraints = False

@property
def num_output_features(self) -> int:
    """Number of features returned by the basis (read-only property)."""
    # Expose the privately stored count without allowing assignment.
    n_features = self._num_output_features
    return n_features

@property
def label(self) -> str:
    """Return the label stored on this basis instance."""
    current_label = self._label
    return current_label
Expand Down Expand Up @@ -1205,6 +1211,129 @@ def _get_default_slicing(
start_slice += self._num_output_features
return split_dict, start_slice

def split_feature_axis(self, x: NDArray, axis: int = 1) -> dict:
    r"""
    Decompose a feature matrix along a specified axis into a dictionary of sub-arrays.

    This method takes a concatenated feature matrix—such as a design matrix generated by
    `Basis.compute_features` or coefficients obtained from a fitted GLM—and decomposes it into
    separate sub-arrays, making it easier to analyze individual basis components.

    The resulting dictionary maps each basis component to its corresponding sub-array(s):

    - For each additive basis component, the top-level key is the `label` of that basis component.
    - If a component has a single input, the corresponding value in the dictionary is the
      sub-array for that input.
    - If a component has multiple inputs, the corresponding value is another dictionary, with
      keys being the input indices (0, 1, 2, ...) and values being sub-arrays of the features.

    The shape of `x` along the specified axis must match the total number of features that the
    basis generates, i.e., `x.shape[axis] == self.num_output_features`.

    Parameters
    ----------
    x : NDArray
        The input tensor to be split based on the feature slices. This can be a multidimensional
        array representing concatenated features, model coefficients, or any other data. The
        shape of `x` along the specified axis should match the total number of features
        generated by the basis.
    axis : int, optional
        The axis of `x` to be split. Defaults to 1. Typically, this is the feature axis of your
        design matrix or model output.

    Returns
    -------
    dict
        A (possibly nested) dictionary where:

        - Top-level keys are the labels of the basis components.
        - Values are sub-arrays of `x` that correspond to each component's features.

        If a basis component has multiple inputs, the value will be another dictionary, with keys
        being input indices (0, 1, 2, ...) and values being sub-arrays of `x`.

    Raises
    ------
    ValueError
        If the shape of `x` along the specified axis does not match `self.num_output_features`.

    Notes
    -----
    - This method relies on `self._get_feature_slicing()` to determine the slice dictionary for
      each component.
    - The slices are applied with a single `jax.tree_util.tree_map` call. `slice` objects are
      not pytree containers, so each one is visited as a leaf and replaced by the matching
      sub-array of `x`; the nesting of the slice dictionary is preserved in the output.

    Examples
    --------
    >>> import numpy as np
    >>> from nemos.basis import BSplineBasis
    >>> from nemos.glm import GLM
    >>> # Define an additive basis
    >>> basis = (
    ...     BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") +
    ...     BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2")
    ... )
    >>> # Generate a sample input array and compute features
    >>> x1, x2 = np.random.randn(20), np.random.randn(20)
    >>> X = basis.compute_features(x1, x2)
    >>> # Split the feature matrix along axis 1
    >>> split_features = basis.split_feature_axis(X, axis=1)
    >>> for feature, arr in split_features.items():
    ...     print(f"{feature}: shape {arr.shape}")
    feature_1: shape (20, 5)
    feature_2: shape (20, 6)
    >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested:
    >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="multi_input", n_basis_input=2)
    >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2))
    >>> split_features_multi = multi_input_basis.split_feature_axis(X_multi, axis=1)
    >>> for feature, sub_dict in split_features_multi.items():
    ...     print(f"{feature}:")
    ...     for input_num, arr in sub_dict.items():
    ...         print(f" input number {int(input_num)}: shape {arr.shape}")
    multi_input:
     input number 0: shape (20, 6)
     input number 1: shape (20, 6)
    >>> # the method can be used to decompose the glm coefficients in the various features
    >>> counts = np.random.poisson(size=20)
    >>> model = GLM().fit(X, counts)
    >>> split_coef = basis.split_feature_axis(model.coef_, axis=0)
    >>> for feature, coef in split_coef.items():
    ...     print(f"{feature}: shape {coef.shape}")
    feature_1: shape (5,)
    feature_2: shape (6,)
    """
    n_features = x.shape[axis]
    if n_features != self.num_output_features:
        raise ValueError(
            "`x.shape[axis]` does not match the expected number of features."
            f" `x.shape[axis] == {n_features}`, while the expected number "
            f"of features is {self.num_output_features}"
        )

    # Get the slice dictionary based on predefined feature slicing
    slice_dict = self._get_feature_slicing()[0]

    def take_along_axis(slice_obj):
        """Return the part of `x` selected by `slice_obj` on `axis`, keeping all other axes whole."""
        index = [slice(None)] * x.ndim  # full slice on every dimension...
        index[axis] = slice_obj  # ...except the feature axis
        return x[tuple(index)]

    # A single traversal is enough: each slice leaf is mapped directly to its sub-array, so no
    # intermediate tree of index tuples (and no custom `is_leaf` to recognise them) is needed.
    return jax.tree_util.tree_map(take_along_axis, slice_dict)


class AdditiveBasis(Basis):
"""
Expand Down

0 comments on commit 8f1b832

Please sign in to comment.