Skip to content

Commit

Permalink
Merge branch 'basis_refactor_pr1' into document_basis
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 4, 2024
2 parents 0fa061f + fb883ce commit 4d2bcb3
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/background/basis/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ warnings.filterwarnings(

## Defining a 1D Basis Object

We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.basis.MSplineEval).
We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval).
The hyperparameters required to initialize this class are:

- The number of basis functions, which should be a positive integer.
Expand Down
8 changes: 4 additions & 4 deletions docs/developers_notes/04-basis_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Abstract Class Basis
└─ Concrete Subclass OrthExponentialBasis
```

The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`_evaluate`](nemos.basis._basis.Basis._evaluate) that is specific for each concrete class. See below for more details.
The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `_evaluate` that is specific for each concrete class. See below for more details.

## The Class `nemos.basis._basis.Basis`

Expand Down Expand Up @@ -61,14 +61,14 @@ This method performs the following steps:

1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis.
3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method.
3. Calls the `_evaluate` method.
4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid.

### Abstract Methods

The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement:

1. [`_evaluate`](nemos.basis._basis.Basis._evaluate): Evaluates a basis over some specified samples.
1. `_evaluate`: Evaluates a basis over some specified samples.
2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis.

## Contributors Guidelines
Expand All @@ -77,7 +77,7 @@ The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the followi
To write a usable (i.e., concrete, non-abstract) basis object, you

- **Must** inherit the abstract superclass [`Basis`](nemos.basis._basis.Basis)
- **Must** define the [`_evaluate`](nemos.basis._basis.Basis._evaluate) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics.
- **Must** define the `_evaluate` and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics.
- **Should not** overwrite the [`compute_features`](nemos.basis._basis.Basis.compute_features) and [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis._basis.Basis).
- **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis._spline_basis.SplineBasis)).

2 changes: 1 addition & 1 deletion docs/how_to_guide/plot_06_glm_pytree.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ Okay, let's use unit number 7.
Now let's set up our design matrix. First, let's fit the head direction by
itself. Head direction is a circular variable (pi and -pi are adjacent to
each other), so we need to use a basis that has this property as well.
[`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) is one such basis.
[`CyclicBSplineEval`](nemos.basis.CyclicBSplineEval) is one such basis.

Let's create our basis and then arrange our data properly.

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/plot_03_grid_cells.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ position = position.interpolate(counts)
```

We can define a two-dimensional basis for position by multiplying two one-dimensional bases,
see [here](../../background/plot_02_ND_basis_function) for more details.
see [here](composing_basis_function) for more details.

```{code-cell} ipython3
basis_2d = nmo.basis.BSplineEval(
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/plot_05_place_cells.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,9 @@ print(count.shape)

For each feature, we will use a different set of basis :

- position : [`MSplineEval`](nemos.basis.basis.MSplineEval)
- theta phase : [`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval)
- speed : [`MSplineEval`](nemos.basis.basis.MSplineEval)
- position : [`MSplineEval`](nemos.basis.MSplineEval)
- theta phase : [`CyclicBSplineEval`](nemos.basis.CyclicBSplineEval)
- speed : [`MSplineEval`](nemos.basis.MSplineEval)


```{code-cell} ipython3
Expand Down
8 changes: 5 additions & 3 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Mixin classes for basis."""

from __future__ import annotations

import copy
import inspect
from typing import Optional, Tuple, Union
Expand Down Expand Up @@ -95,10 +97,10 @@ def _compute_features(self, *xi: ArrayLike):
samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions
except for the sample-axis are flattened, so that the method always returns a matrix.
For example, if samples are of shape (num_samples, 2, 3), the output will be
(num_samples, num_basis_funcs * 2 * 3).
``(num_samples, num_basis_funcs * 2 * 3)``.
The time-axis can be specified at basis initialization by setting the keyword argument ``axis``.
For example, if ``axis == 1`` your samples should be (N1, num_samples N3, ...), the output of
transform will be (num_samples, num_basis_funcs * N1 * N3 *...).
For example, if ``axis == 1`` your samples should be ``(N1, num_samples N3, ...)``, the output of
transform will be ``(num_samples, num_basis_funcs * N1 * N3 *...)``.
Parameters
----------
Expand Down

0 comments on commit 4d2bcb3

Please sign in to comment.