Skip to content

Commit

Permalink
last few comment fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pranmod01 committed Oct 26, 2024
1 parent e2cf5dc commit c65383c
Showing 1 changed file with 17 additions and 48 deletions.
65 changes: 17 additions & 48 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ def fit(self, X: FeatureMatrix, y=None):
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
# Example input
>>> X, y = np.random.normal(size=(100, 2)), np.random.uniform(size=100)
>>> # Example input
>>> X = np.random.normal(size=(100, 2))
# Define and fit tranformation basis
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X) # input must be a 2d array
>>> transformer_fitted = transformer.fit(X)
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand All @@ -243,21 +243,21 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X, y = np.random.normal(size=(10000, 2)), np.random.uniform(size=100)
>>> X = np.random.normal(size=(10000, 2))
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10, mode="conv", window_size=200)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_
>>> transformer_fitted = transformer.fit(X) # input must be a 2d array
>>> transformer_fitted = transformer.fit(X)
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)
>>> # Transform basis
>>> feature_transformed = transformer.transform(X[:, 0:1]) # input must be a 2d array, (num_samples, 1)
>>> feature_transformed = transformer.transform(X[:, 0:1])
"""
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
Expand Down Expand Up @@ -290,14 +290,14 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X, y = np.random.normal(size=(100, 1)), np.random.uniform(size=100)
>>> X = np.random.normal(size=(100, 1))
>>> # Define tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> # Fit and transform basis
>>> feature_transformed = transformer.fit_transform(X) # input must be a 2d array, (num_samples, 1)
>>> feature_transformed = transformer.fit_transform(X)
"""
return self._basis.compute_features(*self._unpack_inputs(X))

Expand Down Expand Up @@ -953,14 +953,10 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
>>> from nemos.basis import MSplineBasis
>>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
>>> for i in range(4):
... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}')
>>> plt.title('M-Spline Basis Functions')
Text(0.5, 1.0, 'M-Spline Basis Functions')
>>> plt.xlabel('Domain')
Text(0.5, 0, 'Domain')
>>> plt.ylabel('Basis Function Value')
Text(0, 0.5, 'Basis Function Value')
>>> p = plt.plot(sample_points, basis_values, label=f'Function {i+1}')
>>> plt.title('M-Spline Basis Functions');
>>> plt.xlabel('Domain');
>>> plt.ylabel('Basis Function Value');
>>> l = plt.legend()
"""
self._check_input_dimensionality(n_samples)
Expand Down Expand Up @@ -1156,24 +1152,17 @@ class AdditiveBasis(Basis):
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(30, 2)), np.random.poisson(size=30)
>>> # X.shape is (n_samples, n_inputs), where n_inputs is the number required by the basis
>>> X = np.random.normal(size=(30, 2))
>>> # define two basis objects and add them
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> additive_basis = nmo.basis.AdditiveBasis(basis1=basis_1, basis2=basis_2)
>>> transformed_X = additive_basis.to_transformer().transform(X)
>>> print(transformed_X.shape)
(30, 25)
>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> additive_basis_2 = additive_basis + basis_3
>>> transformed_X = additive_basis_2.to_transformer().transform(X)
>>> print(transformed_X.shape)
(30, 125)
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1290,23 +1279,17 @@ class MultiplicativeBasis(Basis):
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(30, 3)), np.random.poisson(size=30)
>>> X = np.random.normal(size=(30, 3))
>>> # define two basis and multiply
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> multiplicative_basis = nmo.basis.MultiplicativeBasis(basis1=basis_1, basis2=basis_2)
>>> transformed_X = multiplicative_basis.to_transformer().transform(X[:, 0:2])
>>> print(transformed_X.shape)
(30, 150)
>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> multiplicative_basis_2 = multiplicative_basis * basis_3
>>> transformed_X = multiplicative_basis_2.to_transformer().transform(X)
>>> print(transformed_X.shape)
(30, 15000)
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1743,8 +1726,6 @@ class BSplineBasis(SplineBasis):
>>> from nemos.basis import BSplineBasis
>>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3)
>>> bspline_transformer = bspline_basis.to_transformer()
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis(sample_points)
"""
Expand Down Expand Up @@ -1879,9 +1860,6 @@ class CyclicBSplineBasis(SplineBasis):
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cyclic_basis(sample_points)
>>> X_transformed = cyclic_basis.to_transformer().fit_transform(X)
>>> X_transformed.shape
(1000, 5)
"""

def __init__(
Expand Down Expand Up @@ -2035,9 +2013,6 @@ class RaisedCosineBasisLinear(Basis):
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
>>> X_transformed = cosine_basis.to_transformer().fit_transform(X)
>>> X_transformed.shape
(1000, 5)
# References
------------
Expand Down Expand Up @@ -2248,9 +2223,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
>>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
>>> X_transformed = cosine_basis.to_transformer().fit_transform(X)
>>> X_transformed.shape
(1000, 5)
# References
------------
Expand Down Expand Up @@ -2412,14 +2384,11 @@ class OrthExponentialBasis(Basis):
>>> from nemos.basis import OrthExponentialBasis
>>> X = np.random.normal(size=(1000, 1))
>>> n_basis_funcs = 5
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = ortho_basis(sample_points)
>>> X_transformed = ortho_basis.to_transformer().fit_transform(X)
>>> X_transformed.shape
(1000, 5)
"""

def __init__(
Expand Down Expand Up @@ -2692,7 +2661,7 @@ def bspline(
>>> from nemos.basis import bspline
>>> sample_points = linspace(0, 1, 100)
>>> knots = knots = BSplineBasis(10)._generate_knots(sample_points)
>>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1])
>>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline
>>> bspline_eval.shape
(100, 10)
Expand Down

0 comments on commit c65383c

Please sign in to comment.