diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index 2344d0ea..4af70b40 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -65,8 +65,10 @@ # ----------------- # Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description, # please refer to the [Code References](../../../reference/neurostatslib/basis). After instantiation, all classes -# share the same syntax for basis evaluation. The following is an example of how to instantiate and -# evaluate a log-spaced cosine raised function basis. +# share the same syntax for basis evaluation. +# +# ### The Log-Spaced Raised Cosine Basis +# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter raised_cosine_log = nsl.basis.RaisedCosineBasisLog(n_basis_funcs=10) @@ -81,3 +83,84 @@ plt.plot(samples, eval_basis) plt.show() +# %% +# ### The Fourier Basis +# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and +# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals. +# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and +# interpretation of the model parameters, each of which will represent the relative contribution of a specific +# oscillation frequency to the overall signal. + + +# A Fourier basis can be instantiated with the usual syntax. +# The user can pass the desired frequencies for the basis or +# the frequencies will be set to np.arange(n_basis_funcs//2). +# The number of basis function is required to be even. +fourier_basis = nsl.basis.FourierBasis(n_freqs=4) + +# evaluate on equi-spaced samples +samples, eval_basis = fourier_basis.evaluate_on_grid(1000) + +# plot the `sin` and `cos` separately +plt.figure(figsize=(6, 3)) +plt.subplot(121) +plt.title("Cos") +plt.plot(samples, eval_basis[:, :4]) +plt.subplot(122) +plt.title("Sin") +plt.plot(samples, eval_basis[:, 4:]) +plt.tight_layout() + +# %% +# !!! note "Fourier basis convolution and Fourier transform" +# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is +# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$ +# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $. +# +# When computing the cross-correlation of a signal with the Fourier basis functions, +# we essentially measure how well the signal correlates with sinusoids of different frequencies, +# within a specified temporal window. This process mirrors the operation performed by the Fourier transform. +# Therefore, it becomes clear that computing the cross-correlation of a signal with the Fourier basis defined here +# is equivalent to computing the discrete Fourier transform on a sliding window of the same size +# as that of the basis. + +n_samples = 1000 +n_freqs = 20 + +# define a signal +signal = np.random.normal(size=n_samples) + +# evaluate the basis +_, eval_basis = nsl.basis.FourierBasis(n_freqs=n_freqs).evaluate_on_grid(n_samples) + +# compute the cross-corr with the signal and the basis +# Note that we are inverting the time axis of the basis because we are aiming +# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis. +xcorr = np.array( + [ + np.convolve(eval_basis[::-1, k], signal, mode="valid")[0] + for k in range(2 * n_freqs - 1) + ] +) + +# compute the power (add back sin(0 * t) = 0) +fft_complex = np.fft.fft(signal) +fft_amplitude = np.abs(fft_complex[:n_freqs]) +fft_phase = np.angle(fft_complex[:n_freqs]) +# compute the phase and amplitude from the convolution +xcorr_phase = np.arctan2(np.hstack([[0], xcorr[n_freqs:]]), xcorr[:n_freqs]) +xcorr_aplitude = np.sqrt(xcorr[:n_freqs] ** 2 + np.hstack([[0], xcorr[n_freqs:]]) ** 2) + +fig, ax = plt.subplots(1, 2) +ax[0].set_aspect("equal") +ax[0].set_title("Signal amplitude") +ax[0].scatter(fft_amplitude, xcorr_aplitude) +ax[0].set_xlabel("FFT") +ax[0].set_ylabel("cross-correlation") + +ax[1].set_aspect("equal") +ax[1].set_title("Signal phase") +ax[1].scatter(fft_phase, xcorr_phase) +ax[1].set_xlabel("FFT") +ax[1].set_ylabel("cross-correlation") +plt.tight_layout()