Skip to content

Commit

Permalink
fixed some core issues with irasa_sprint
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 7, 2024
1 parent 8b5f85a commit 23abaac
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 52 deletions.
25 changes: 13 additions & 12 deletions examples/basic_functionality.ipynb

Large diffs are not rendered by default.

41 changes: 26 additions & 15 deletions examples/irasa_sprint.ipynb

Large diffs are not rendered by default.

183 changes: 183 additions & 0 deletions examples/test_td.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#%%
from neurodsp.sim import sim_combined
from neurodsp.utils import create_times
import numpy as np
import scipy.signal as dsp


fs = 1000
n_seconds = 60
duration=2
overlap=0.5

sim_components = {'sim_powerlaw': {'exponent' : -1},
'sim_oscillation': {'freq' : 10}}


sig = sim_combined(n_seconds=n_seconds, fs=fs, components=sim_components)
times = create_times(n_seconds=n_seconds, fs=fs)
#%%
import fractions
kwargs_psd = {'nperseg': duration*fs,
'noverlap': duration*fs*overlap}

resampling_factor = 1.5

rat = fractions.Fraction(str(resampling_factor))
up, down = rat.numerator, rat.denominator

# Much faster than FFT-based resampling
data_up = dsp.resample_poly(sig, up, down, axis=-1)
data_down = dsp.resample_poly(sig, down, up, axis=-1)

# Calculate an up/downsampled version of the PSD using same params as original
win_duration = 2
hop = 100

nperseg = int(np.floor(fs * win_duration))

win = dsp.windows.hann(nperseg)

SFT = dsp.ShortTimeFFT(win, hop=hop, fs=fs, scale_to='psd')
t_inc = SFT.T
psd = SFT.spectrogram(sig, detr='constant')

hop_up = int(hop * resampling_factor)
SFT_u = dsp.ShortTimeFFT.from_window('hann',
nperseg=nperseg,
fs=fs * resampling_factor,
noverlap=nperseg-hop_up)
psd_up = SFT_u.stft(data_up)


#%%
hop_dw = int(hop / resampling_factor)
noverlap=nperseg-hop_dw
N = len(data_down)
SFT_d = dsp.ShortTimeFFT.from_window('hann',
nperseg=nperseg,
fs=fs / resampling_factor,
noverlap=nperseg-hop_dw,
fft_mode='centered',)
psd_dw = SFT_d.spectrogram(data_down, p0=0, p1=(N-noverlap)//SFT_d.hop, k_offset=N//2)

# %%
psd_dw.shape

# %%
psd.shape
# %%
psd_up.shape


# %%
import numpy as np
import scipy.signal as dsp
import fractions
from neurodsp.sim import sim_oscillation, sim_powerlaw

# Example signal and parameters
# Set some general settings, to be used across all simulations
fs = 500
n_seconds = 15
duration=4
overlap=0.5

# Create a times vector for the simulations
#times = create_times(n_seconds, fs)


alpha = sim_oscillation(n_seconds=.5, fs=fs, freq=10)
no_alpha = np.zeros(len(alpha))
beta = sim_oscillation(n_seconds=.5, fs=fs, freq=25)
no_beta = np.zeros(len(beta))

exp_1 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=-1)
exp_2 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=-2)


alphas = np.concatenate([no_alpha, alpha, no_alpha, alpha, no_alpha])
betas = np.concatenate([beta, no_beta, beta, no_beta, beta])

sig = np.concatenate([exp_1 + alphas,
exp_1 + alphas + betas,
exp_1 + betas,
exp_2 + alphas,
exp_2 + alphas + betas,
exp_2 + betas, ])

resampling_factor = 1.1 # Resampling factor
win_duration = 1 # Window duration for spectrogram
hop = 10 # Hop size for spectrogram
hset = np.arange(1, 2, 0.05).round(2)
freq_res = 0.5
nfft = int(fs / freq_res)

# Original spectrogram
# Window and hop size for spectrogram
nperseg = int(np.floor(fs * win_duration))
win = dsp.windows.hann(nperseg)
freq, t, psd = dsp.stft(sig, fs=fs, nfft=nfft, window=win, nperseg=nperseg,
noverlap=nperseg - hop, scaling='psd')

psd = (np.abs(psd) ** 2)

average_psd = np.zeros([len(hset), *psd.shape])

for i, resampling_factor in enumerate(hset):
# Calculate resampling factors
rat = fractions.Fraction(str(resampling_factor))
up, down = rat.numerator, rat.denominator

# Resample the signal
data_up = dsp.resample_poly(sig, up, down, axis=-1)
data_down = dsp.resample_poly(sig, down, up, axis=-1)

# Upsampled spectrogram
hop_up = int(hop * resampling_factor)
f_up, t_up, psd_up = dsp.stft(data_up, nfft=nfft, fs=fs * resampling_factor,
window=win, nperseg=nperseg,
noverlap=nperseg - hop_up, scaling='psd')

# Downsampled spectrogram
hop_down = int(hop / resampling_factor)
f_dw, t_dw, psd_dw = dsp.stft(data_down, nfft=nfft, fs=fs / resampling_factor,
window=win, nperseg=nperseg,
noverlap=nperseg - hop_down, scaling='psd')

# Ensure the time axis has the same number of values by adjusting the hop size
psd_up = psd_up[:, :psd.shape[1]]
psd_dw = psd_dw[:, :psd.shape[1]]

# Average the PSDs
average_psd[i,:,:] = np.sqrt((np.abs(psd_up) ** 2) * (np.abs(psd_dw) ** 2))
# If needed, the time axis can be derived from the spectrogram output
#time_axis = np.linspace(0, len(sig) / fs, psd.shape[1])

print("Average PSD shape:", average_psd.shape)
#print("Time axis length:", len(time_axis))

aperiodic = np.median(average_psd, axis=0)
periodic = psd - aperiodic

#%%
from neurodsp.plts import plot_timefrequency#
import matplotlib.pyplot as plt
f, axes = plt.subplots(figsize=(14, 4), ncols=3)

fmask = freq < 50

plot_timefrequency(t, freq[fmask], psd[fmask,:], vmin=0, ax=axes[0])
plot_timefrequency(t, freq[fmask], aperiodic[fmask,:], vmin=0, ax=axes[1])
plot_timefrequency(t, freq[fmask], periodic[fmask,:], vmin=0, ax=axes[2])

# %%
from pyrasa.irasa import irasa_sprint

irasa_sprint_spectrum = irasa_sprint(sig[np.newaxis, :], fs=fs,
band=(1, 100),
freq_res=.5,
hop=100,
win_duration=1.,
hset_info=(1.05, 2., 0.05))
# %%
4 changes: 2 additions & 2 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def irasa_sprint( # noqa PLR0915 C901

hset = np.round(np.arange(*hset_info), hset_accuracy)

mfft = int(fs / freq_res)
nfft = int(fs / freq_res)
win_kwargs = {'win_func': win_func, 'win_func_kwargs': win_func_kwargs}
dpss_settings = {
'time_bandwidth': dpss_settings_time_bandwidth,
Expand All @@ -261,7 +261,7 @@ def irasa_sprint( # noqa PLR0915 C901
}

irasa_kwargs: IrasaSprintKwargsTyped = {
'mfft': mfft,
'nfft': nfft,
'hop': hop,
'win_duration': win_duration,
'dpss_settings': dpss_settings,
Expand Down
34 changes: 11 additions & 23 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import scipy.signal as dsp
from scipy.signal import ShortTimeFFT

from pyrasa.utils.types import IrasaFun

Expand Down Expand Up @@ -226,9 +225,9 @@ def _compute_psd_welch(
def _compute_sgramm( # noqa C901
x: np.ndarray,
fs: int,
mfft: int,
hop: int,
nfft: int,
win_duration: float,
hop: int,
dpss_settings: dict,
win_kwargs: dict,
up_down: str | None = None,
Expand All @@ -237,42 +236,31 @@ def _compute_sgramm( # noqa C901
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Function to compute spectrograms"""

if h is None:
nperseg = int(np.floor(fs * win_duration))
elif np.logical_and(h is not None, up_down == 'up'):
nperseg = int(np.floor(fs * win_duration * h))
nperseg = int(np.floor(fs * win_duration))

if np.logical_and(h is not None, up_down == 'up'):
hop = int(hop * h)
elif np.logical_and(h is not None, up_down == 'down'):
nperseg = int(np.floor(fs * win_duration / h))
hop = int(hop / h)

win, ratios = _get_windows(nperseg, dpss_settings, **win_kwargs)

sgramms = []
for cur_win in win:
SFT = ShortTimeFFT(cur_win, hop=hop, mfft=mfft, fs=fs, scale_to='psd') # noqa N806
cur_sgramm = SFT.spectrogram(x, detr='constant')
sgramms.append(cur_sgramm)
freq, time, sgramm = dsp.stft(
x, nfft=nfft, nperseg=nperseg, noverlap=nperseg - hop, fs=fs, window=cur_win, scaling='psd'
)
sgramm = np.abs(sgramm) ** 2
sgramms.append(sgramm)

if ratios is None:
sgramm = np.mean(sgramms, axis=0)
else:
weighted_sgramms = [ratios[ix] * cur_sgramm for ix, cur_sgramm in enumerate(sgramms)]
sgramm = np.sum(weighted_sgramms, axis=0) / np.sum(ratios)

time = _gen_time_from_sft(SFT, x)
freq = SFT.f[SFT.f > 0]

# subsample the upsampled data in the time domain to allow averaging
# This is necessary as division by h can cause slight rounding differences that
# result in actual unintended temporal differences in up/dw for very long segments.
if time_orig is not None:
sgramm = np.array([_find_nearest(sgramm, time, t) for t in time_orig])
max_t_ix = time_orig.shape[0]
# swapping axes is necessitated by _find_nearest
sgramm = np.swapaxes(
np.swapaxes(sgramm[:max_t_ix, :, :], 1, 2), 0, 2
) # cut time axis for up/downsampled data to allow averaging
sgramm = sgramm[:, :, : time_orig.shape[-1]]

sgramm = np.squeeze(sgramm) # bring in proper format

Expand Down

0 comments on commit 23abaac

Please sign in to comment.