Skip to content

Commit

Permalink
linted
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 13, 2023
1 parent 0a57674 commit 8ebf2dc
Showing 1 changed file with 85 additions and 2 deletions.
87 changes: 85 additions & 2 deletions docs/examples/plot_1D_basis_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@
# -----------------
# Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
# please refer to the [Code References](../../../reference/neurostatslib/basis). After instantiation, all classes
# share the same syntax for basis evaluation. The following is an example of how to instantiate and
# evaluate a log-spaced cosine raised function basis.
# share the same syntax for basis evaluation.
#
# ### The Log-Spaced Raised Cosine Basis
# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis.

# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
raised_cosine_log = nsl.basis.RaisedCosineBasisLog(n_basis_funcs=10)
Expand All @@ -81,3 +83,84 @@
plt.plot(samples, eval_basis)
plt.show()

# %%
# ### The Fourier Basis
# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and
# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals.
# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and
# interpretation of the model parameters, each of which will represent the relative contribution of a specific
# oscillation frequency to the overall signal.


# A Fourier basis can be instantiated with the usual syntax.
# The user can pass the desired frequencies for the basis or
# the frequencies will be set to np.arange(n_basis_funcs//2).
# The number of basis function is required to be even.
fourier_basis = nsl.basis.FourierBasis(n_freqs=4)

# evaluate on equi-spaced samples
samples, eval_basis = fourier_basis.evaluate_on_grid(1000)

# plot the `sin` and `cos` separately
plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.title("Cos")
plt.plot(samples, eval_basis[:, :4])
plt.subplot(122)
plt.title("Sin")
plt.plot(samples, eval_basis[:, 4:])
plt.tight_layout()

# %%
# !!! note "Fourier basis convolution and Fourier transform"
# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is
# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$
# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $.
#
# When computing the cross-correlation of a signal with the Fourier basis functions,
# we essentially measure how well the signal correlates with sinusoids of different frequencies,
# within a specified temporal window. This process mirrors the operation performed by the Fourier transform.
# Therefore, it becomes clear that computing the cross-correlation of a signal with the Fourier basis defined here
# is equivalent to computing the discrete Fourier transform on a sliding window of the same size
# as that of the basis.

n_samples = 1000
n_freqs = 20

# define a signal
signal = np.random.normal(size=n_samples)

# evaluate the basis
_, eval_basis = nsl.basis.FourierBasis(n_freqs=n_freqs).evaluate_on_grid(n_samples)

# compute the cross-corr with the signal and the basis
# Note that we are inverting the time axis of the basis because we are aiming
# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis.
xcorr = np.array(
[
np.convolve(eval_basis[::-1, k], signal, mode="valid")[0]
for k in range(2 * n_freqs - 1)
]
)

# compute the power (add back sin(0 * t) = 0)
fft_complex = np.fft.fft(signal)
fft_amplitude = np.abs(fft_complex[:n_freqs])
fft_phase = np.angle(fft_complex[:n_freqs])
# compute the phase and amplitude from the convolution
xcorr_phase = np.arctan2(np.hstack([[0], xcorr[n_freqs:]]), xcorr[:n_freqs])
xcorr_aplitude = np.sqrt(xcorr[:n_freqs] ** 2 + np.hstack([[0], xcorr[n_freqs:]]) ** 2)

fig, ax = plt.subplots(1, 2)
ax[0].set_aspect("equal")
ax[0].set_title("Signal amplitude")
ax[0].scatter(fft_amplitude, xcorr_aplitude)
ax[0].set_xlabel("FFT")
ax[0].set_ylabel("cross-correlation")

ax[1].set_aspect("equal")
ax[1].set_title("Signal phase")
ax[1].scatter(fft_phase, xcorr_phase)
ax[1].set_xlabel("FFT")
ax[1].set_ylabel("cross-correlation")
plt.tight_layout()

0 comments on commit 8ebf2dc

Please sign in to comment.