Skip to content

Commit

Permalink
fixed aperiodic error estimate for mutliple channels
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 27, 2024
1 parent 315860d commit c7fe3c0
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 53 deletions.
84 changes: 42 additions & 42 deletions examples/hset_optimization.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/irasa_mne.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down
25 changes: 15 additions & 10 deletions pyrasa/utils/irasa_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,24 @@ def get_aperiodic_error(self, peak_kwargs: None | dict = None) -> np.ndarray:

if peak_kwargs is None:
peak_kwargs = {}

aperiodic_errors = []
# get absolute periodic spectrum
aperiodic_error = np.abs(self.periodic[0, :])
for ix in range(self.periodic.shape[0]):
aperiodic_error = np.abs(self.periodic[ix, :])

# zero-out peaks
peaks = self.get_peaks(**peak_kwargs)
freqs = self.freqs

# zero-out peaks
peaks = self.get_peaks(**peak_kwargs)
freqs = self.freqs
for _, peak in peaks.iterrows():
cur_upper = peak['cf'] + peak['bw']
cur_lower = peak['cf'] - peak['bw']

for _, peak in peaks.iterrows():
cur_upper = peak['cf'] + peak['bw']
cur_lower = peak['cf'] - peak['bw']
freq_mask = np.logical_and(freqs < cur_upper, freqs > cur_lower)

freq_mask = np.logical_and(freqs < cur_upper, freqs > cur_lower)
aperiodic_error[freq_mask] = 0

aperiodic_error[freq_mask] = 0
aperiodic_errors.append(aperiodic_error)

return aperiodic_error
return np.array(aperiodic_errors)

0 comments on commit c7fe3c0

Please sign in to comment.