diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index bddf798d..29f32993 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -46,7 +46,7 @@ warnings.filterwarnings( ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.basis.MSplineEval). +We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval). The hyperparameters required to initialize this class are: - The number of basis functions, which should be a positive integer. diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index f7823a86..45decfb1 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -26,7 +26,7 @@ Abstract Class Basis └─ Concrete Subclass OrthExponentialBasis ``` -The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`_evaluate`](nemos.basis._basis.Basis._evaluate) that is specific for each concrete class. See below for more details. +The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `_evaluate` that is specific for each concrete class. See below for more details. ## The Class `nemos.basis._basis.Basis` @@ -61,14 +61,14 @@ This method performs the following steps: 1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis. -3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method. +3. Calls the `_evaluate` method. 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid. ### Abstract Methods The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement: -1. [`_evaluate`](nemos.basis._basis.Basis._evaluate): Evaluates a basis over some specified samples. +1. `_evaluate`: Evaluates a basis over some specified samples. 2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis. ## Contributors Guidelines @@ -77,7 +77,7 @@ The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the followi To write a usable (i.e., concrete, non-abstract) basis object, you - **Must** inherit the abstract superclass [`Basis`](nemos.basis._basis.Basis) -- **Must** define the [`_evaluate`](nemos.basis._basis.Basis._evaluate) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. +- **Must** define the `_evaluate` and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. - **Should not** overwrite the [`compute_features`](nemos.basis._basis.Basis.compute_features) and [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis._basis.Basis). - **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis._spline_basis.SplineBasis)). diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 36910bc3..e5949f58 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -274,7 +274,7 @@ Okay, let's use unit number 7. Now let's set up our design matrix. First, let's fit the head direction by itself. Head direction is a circular variable (pi and -pi are adjacent to each other), so we need to use a basis that has this property as well. -[`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) is one such basis. +[`CyclicBSplineEval`](nemos.basis.CyclicBSplineEval) is one such basis. Let's create our basis and then arrange our data properly. diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index ec9c3fca..6d244f28 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -145,7 +145,7 @@ position = position.interpolate(counts) ``` We can define a two-dimensional basis for position by multiplying two one-dimensional bases, -see [here](../../background/plot_02_ND_basis_function) for more details. +see [here](composing_basis_function) for more details. ```{code-cell} ipython3 basis_2d = nmo.basis.BSplineEval( diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index e1d492b2..4657c0f5 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -336,9 +336,9 @@ print(count.shape) For each feature, we will use a different set of basis : - - position : [`MSplineEval`](nemos.basis.basis.MSplineEval) - - theta phase : [`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) - - speed : [`MSplineEval`](nemos.basis.basis.MSplineEval) + - position : [`MSplineEval`](nemos.basis.MSplineEval) + - theta phase : [`CyclicBSplineEval`](nemos.basis.CyclicBSplineEval) + - speed : [`MSplineEval`](nemos.basis.MSplineEval) ```{code-cell} ipython3 diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 3ae84376..16332e6e 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -1,5 +1,7 @@ """Mixin classes for basis.""" +from __future__ import annotations + import copy import inspect from typing import Optional, Tuple, Union @@ -95,10 +97,10 @@ def _compute_features(self, *xi: ArrayLike): samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions except for the sample-axis are flattened, so that the method always returns a matrix. For example, if samples are of shape (num_samples, 2, 3), the output will be - (num_samples, num_basis_funcs * 2 * 3). + ``(num_samples, num_basis_funcs * 2 * 3)``. The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. - For example, if ``axis == 1`` your samples should be (N1, num_samples N3, ...), the output of - transform will be (num_samples, num_basis_funcs * N1 * N3 *...). + For example, if ``axis == 1`` your samples should be ``(N1, num_samples N3, ...)``, the output of + transform will be ``(num_samples, num_basis_funcs * N1 * N3 *...)``. Parameters ----------