From c65383cc129d0bbe210d77d6acc119caafb8c260 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Sat, 26 Oct 2024 14:02:52 -0700 Subject: [PATCH] last few comment fixes --- src/nemos/basis.py | 65 ++++++++++++---------------------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 2a9e12dc..05a2258a 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -210,13 +210,13 @@ def fit(self, X: FeatureMatrix, y=None): >>> import numpy as np >>> from nemos.basis import MSplineBasis, TransformerBasis - # Example input - >>> X, y = np.random.normal(size=(100, 2)), np.random.uniform(size=100) + >>> # Example input + >>> X = np.random.normal(size=(100, 2)) - # Define and fit tranformation basis + >>> # Define and fit tranformation basis >>> basis = MSplineBasis(10) >>> transformer = TransformerBasis(basis) - >>> transformer_fitted = transformer.fit(X) # input must be a 2d array + >>> transformer_fitted = transformer.fit(X) """ self._basis._set_kernel(*self._unpack_inputs(X)) return self @@ -243,7 +243,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> from nemos.basis import MSplineBasis, TransformerBasis >>> # Example input - >>> X, y = np.random.normal(size=(10000, 2)), np.random.uniform(size=100) + >>> X = np.random.normal(size=(10000, 2)) >>> # Define and fit tranformation basis >>> basis = MSplineBasis(10, mode="conv", window_size=200) @@ -251,13 +251,13 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ - >>> transformer_fitted = transformer.fit(X) # input must be a 2d array + >>> transformer_fitted = transformer.fit(X) >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) >>> transformer_fitted.kernel_.shape (200, 10) >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) # input must be a 2d array, (num_samples, 1) + >>> feature_transformed = transformer.transform(X[:, 0:1]) """ # transpose does not work with pynapple # can't use func(*X.T) to unwrap @@ -290,14 +290,14 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> from nemos.basis import MSplineBasis, TransformerBasis >>> # Example input - >>> X, y = np.random.normal(size=(100, 1)), np.random.uniform(size=100) + >>> X = np.random.normal(size=(100, 1)) >>> # Define tranformation basis >>> basis = MSplineBasis(10) >>> transformer = TransformerBasis(basis) >>> # Fit and transform basis - >>> feature_transformed = transformer.fit_transform(X) # input must be a 2d array, (num_samples, 1) + >>> feature_transformed = transformer.fit_transform(X) """ return self._basis.compute_features(*self._unpack_inputs(X)) @@ -953,14 +953,10 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: >>> from nemos.basis import MSplineBasis >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) - >>> for i in range(4): - ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') - >>> plt.title('M-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') - >>> plt.xlabel('Domain') - Text(0.5, 0, 'Domain') - >>> plt.ylabel('Basis Function Value') - Text(0, 0.5, 'Basis Function Value') + >>> p = plt.plot(sample_points, basis_values, label=f'Function {i+1}') + >>> plt.title('M-Spline Basis Functions'); + >>> plt.xlabel('Domain'); + >>> plt.ylabel('Basis Function Value'); >>> l = plt.legend() """ self._check_input_dimensionality(n_samples) @@ -1156,24 +1152,17 @@ class AdditiveBasis(Basis): >>> # Generate sample data >>> import numpy as np >>> import nemos as nmo - >>> X, y = np.random.normal(size=(30, 2)), np.random.poisson(size=30) - >>> # X.shape is (n_samples, n_inputs), where n_inputs is the number required by the basis + >>> X = np.random.normal(size=(30, 2)) >>> # define two basis objects and add them >>> basis_1 = nmo.basis.BSplineBasis(10) >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) >>> additive_basis = nmo.basis.AdditiveBasis(basis1=basis_1, basis2=basis_2) - >>> transformed_X = additive_basis.to_transformer().transform(X) - >>> print(transformed_X.shape) - (30, 25) >>> # can add another basis to the AdditiveBasis object >>> X = np.random.normal(size=(30, 3)) >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) >>> additive_basis_2 = additive_basis + basis_3 - >>> transformed_X = additive_basis_2.to_transformer().transform(X) - >>> print(transformed_X.shape) - (30, 125) """ def __init__(self, basis1: Basis, basis2: Basis) -> None: @@ -1290,23 +1279,17 @@ class MultiplicativeBasis(Basis): >>> # Generate sample data >>> import numpy as np >>> import nemos as nmo - >>> X, y = np.random.normal(size=(30, 3)), np.random.poisson(size=30) + >>> X = np.random.normal(size=(30, 3)) >>> # define two basis and multiply >>> basis_1 = nmo.basis.BSplineBasis(10) >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) >>> multiplicative_basis = nmo.basis.MultiplicativeBasis(basis1=basis_1, basis2=basis_2) - >>> transformed_X = multiplicative_basis.to_transformer().transform(X[:, 0:2]) - >>> print(transformed_X.shape) - (30, 150) >>> # Can multiply or add another basis to the AdditiveBasis object >>> # This will cause the number of output features of the result basis to grow accordingly >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) >>> multiplicative_basis_2 = multiplicative_basis * basis_3 - >>> transformed_X = multiplicative_basis_2.to_transformer().transform(X) - >>> print(transformed_X.shape) - (30, 15000) """ def __init__(self, basis1: Basis, basis2: Basis) -> None: @@ -1743,8 +1726,6 @@ class BSplineBasis(SplineBasis): >>> from nemos.basis import BSplineBasis >>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3) - >>> bspline_transformer = bspline_basis.to_transformer() - >>> sample_points = linspace(0, 1, 100) >>> basis_functions = bspline_basis(sample_points) """ @@ -1879,9 +1860,6 @@ class CyclicBSplineBasis(SplineBasis): >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10) >>> sample_points = linspace(0, 1, 100) >>> basis_functions = cyclic_basis(sample_points) - >>> X_transformed = cyclic_basis.to_transformer().fit_transform(X) - >>> X_transformed.shape - (1000, 5) """ def __init__( @@ -2035,9 +2013,6 @@ class RaisedCosineBasisLinear(Basis): >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) >>> sample_points = linspace(0, 1, 100) >>> basis_functions = cosine_basis(sample_points) - >>> X_transformed = cosine_basis.to_transformer().fit_transform(X) - >>> X_transformed.shape - (1000, 5) # References ------------ @@ -2248,9 +2223,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): >>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10) >>> sample_points = linspace(0, 1, 100) >>> basis_functions = cosine_basis(sample_points) - >>> X_transformed = cosine_basis.to_transformer().fit_transform(X) - >>> X_transformed.shape - (1000, 5) # References ------------ @@ -2412,14 +2384,11 @@ class OrthExponentialBasis(Basis): >>> from nemos.basis import OrthExponentialBasis >>> X = np.random.normal(size=(1000, 1)) >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates >>> window_size=10 >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) >>> sample_points = linspace(0, 1, 100) >>> basis_functions = ortho_basis(sample_points) - >>> X_transformed = ortho_basis.to_transformer().fit_transform(X) - >>> X_transformed.shape - (1000, 5) """ def __init__( @@ -2692,7 +2661,7 @@ def bspline( >>> from nemos.basis import bspline >>> sample_points = linspace(0, 1, 100) - >>> knots = knots = BSplineBasis(10)._generate_knots(sample_points) + >>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1]) >>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline >>> bspline_eval.shape (100, 10)