Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add option to store and return TFR taper weights #12910

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
9fe1fb6
Add option to store and return tfr taper weights
tsbinns Oct 22, 2024
45c6a0b
Merge remote-tracking branch 'upstream/main' into add_tfr_weights
tsbinns Oct 22, 2024
82fc2f7
Update docstrings
tsbinns Oct 22, 2024
9f30a59
Merge branch 'main' into add_tfr_weights
tsbinns Oct 22, 2024
a49f934
Remove whitespace
tsbinns Oct 22, 2024
48afced
Merge branch 'add_tfr_weights' of https://github.com/tsbinns/mne-pyth…
tsbinns Oct 22, 2024
7c3dcfa
Add PR num
tsbinns Oct 22, 2024
8c16716
Revert "Update docstrings"
tsbinns Oct 22, 2024
51b8cd0
Remove outdated default setting
tsbinns Oct 22, 2024
2f9a4b4
Reapply "Update docstrings"
tsbinns Oct 22, 2024
b4537b2
Update docstrings
tsbinns Oct 22, 2024
f155238
Merge branch 'main' into add_tfr_weights
tsbinns Oct 24, 2024
2a03e9b
Merge branch 'main' into add_tfr_weights
tsbinns Oct 28, 2024
045d9a2
Merge branch 'main' into add_tfr_weights
tsbinns Oct 29, 2024
8d645bb
Enforce return_weights as named param
tsbinns Oct 29, 2024
5ad9bd5
Merge branch 'main' into add_tfr_weights
tsbinns Dec 9, 2024
1c02b40
Add missing test coverage
tsbinns Dec 9, 2024
54f2a32
Add changelog entry
tsbinns Dec 9, 2024
6a23556
Merge branch 'fix_tfr_tapers' into fix_tfr_multitapers
tsbinns Dec 9, 2024
a107991
Begin add support for tapers in array objs
tsbinns Dec 9, 2024
01c486c
Begin add support for tapers in array objs
tsbinns Dec 9, 2024
ca27179
Fix docstring entries
tsbinns Dec 9, 2024
b14a100
Fix faulty state check
tsbinns Dec 10, 2024
972aba2
Add weights to AverageTFR
tsbinns Dec 10, 2024
e11fa2b
Expand test coverage
tsbinns Dec 10, 2024
aaef4b7
Merge branch 'main' into add_tfr_weights
tsbinns Dec 10, 2024
999d122
Disallow aggregating tapers in combine_tfr
tsbinns Dec 10, 2024
e12b09a
Updated docstrings
tsbinns Dec 10, 2024
dd61955
Merge branch 'main' into add_tfr_weights
tsbinns Dec 10, 2024
728701e
Add placeholder versionadded tags
tsbinns Dec 10, 2024
6af3310
Merge branch 'add_tfr_weights' of https://github.com/tsbinns/mne-pyth…
tsbinns Dec 10, 2024
e3a3c4b
Merge remote-tracking branch 'upstream/main' into add_tfr_weights
tsbinns Dec 11, 2024
de39d25
Begin fixing to_data_frame
tsbinns Dec 11, 2024
80126a7
Fix to_data_frame bug with tapers
tsbinns Dec 11, 2024
82dfab9
Fix plotting with tapers
tsbinns Dec 14, 2024
5b150aa
Merge branch 'main' into add_tfr_weights
tsbinns Dec 14, 2024
0d3d85d
Merge branch 'main' into add_tfr_weights
tsbinns Dec 19, 2024
012bd94
Add version tag
tsbinns Dec 20, 2024
e5eedee
Merge branch 'main' into add_tfr_weights
tsbinns Jan 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def tfr_array_multitaper(
use_fft=True,
decim=1,
output="complex",
return_weights=False,
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
n_jobs=None,
*,
verbose=None,
Expand Down Expand Up @@ -502,6 +503,13 @@ def tfr_array_multitaper(
* ``'itc'`` : inter-trial coherence.
* ``'avg_power_itc'`` : average of single trial power and inter-trial
coherence across trials.

return_weights : bool, default False
If True, return the taper weights. Only applies if ``output='complex'`` or
``'phase'``.

.. versionadded:: 1.9.0

%(n_jobs)s
The parallelization is implemented across channels.
%(verbose)s
Expand All @@ -520,6 +528,9 @@ def tfr_array_multitaper(
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
``return_weights=True``.

See Also
--------
Expand Down Expand Up @@ -550,6 +561,7 @@ def tfr_array_multitaper(
use_fft=use_fft,
decim=decim,
output=output,
return_weights=return_weights,
n_jobs=n_jobs,
verbose=verbose,
)
14 changes: 9 additions & 5 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,21 @@ def test_tfr_morlet():
def test_dpsswavelet():
"""Test DPSS tapers."""
freqs = np.arange(5, 25, 3)
Ws = _make_dpss(
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True
Ws, weights = _make_dpss(
1000,
freqs=freqs,
n_cycles=freqs / 2.0,
time_bandwidth=4.0,
zero_mean=True,
return_weights=True,
)

assert len(Ws) == 3 # 3 tapers expected
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs)

# Check that zero mean is true
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5

assert len(Ws[0]) == len(freqs) # As many wavelets as asked for


@pytest.mark.slowtest
def test_tfr_multitaper():
Expand Down
97 changes: 75 additions & 22 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,11 @@ def _make_dpss(
-------
Ws : list of array
The wavelets time series.
Cs : list of array
The concentration weights. Only returned if return_weights=True.
"""
Ws = list()
Cs = list()

freqs = np.array(freqs)
if np.any(freqs <= 0):
Expand All @@ -281,6 +284,7 @@ def _make_dpss(

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
Expand All @@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I am somewhat unsure on. The existing implementation is to just use conc as-is, however in the MNE-Connectivity implementation that sqrt is taken: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed for the PSD computation that the square root of the weights is also taken, so I think this is okay:

weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]


Wm.append(Wk)
Cm.append(Ck)

Ws.append(Wm)
Cs.append(Cm)
if return_weights:
return Ws, conc
return Ws, Cs
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
return Ws


Expand Down Expand Up @@ -428,6 +435,7 @@ def _compute_tfr(
use_fft=True,
decim=1,
output="complex",
return_weights=False,
n_jobs=None,
*,
verbose=None,
Expand Down Expand Up @@ -479,6 +487,9 @@ def _compute_tfr(
* 'avg_power_itc' : average of single trial power and inter-trial
coherence across trials.

return_weights : bool, default False
Whether to return the taper weights. Only applies if method='multitaper' and
output='complex' or 'phase'.
%(n_jobs)s
The number of epochs to process at the same time. The parallelization
is implemented across channels.
Expand All @@ -495,6 +506,10 @@ def _compute_tfr(
n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
real values in the ``output`` contain average power' and the imaginary
values contain the ITC: ``out = avg_power + i * itc``.

weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if method='multitaper', output='complex' or
'phase', and return_weights=True.
"""
# Check data
epoch_data = np.asarray(epoch_data)
Expand All @@ -516,6 +531,9 @@ def _compute_tfr(
decim,
output,
)
return_weights = (
return_weights and method == "multitaper" and output in ["complex", "phase"]
)

decim = _ensure_slice(decim)
if (freqs > sfreq / 2.0).any():
Expand All @@ -531,13 +549,18 @@ def _compute_tfr(
Ws = [W] # to have same dimensionality as the 'multitaper' case

elif method == "multitaper":
Ws = _make_dpss(
out = _make_dpss(
sfreq,
freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
zero_mean=zero_mean,
return_weights=return_weights,
)
if return_weights:
Ws, weights = out
else:
Ws = out

# Check wavelets
if len(Ws[0][0]) > epoch_data.shape[2]:
Expand All @@ -561,6 +584,8 @@ def _compute_tfr(
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ["complex", "phase"] and method == "multitaper":
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
if return_weights:
weights = np.array(weights)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -585,6 +610,9 @@ def _compute_tfr(
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)

if return_weights:
return out, weights
return out


Expand Down Expand Up @@ -1187,9 +1215,6 @@ def __init__(
f'{classname} got unsupported parameter value{_pl(problem)} '
f'{" and ".join(problem)}.'
)
# shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release)
if method == "morlet":
method_kw.setdefault("zero_mean", True)
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
# check method
valid_methods = ["morlet", "multitaper"]
if isinstance(inst, BaseEpochs):
Expand All @@ -1203,6 +1228,9 @@ def __init__(
method_kw.setdefault("output", "power")
self._freqs = np.asarray(freqs, dtype=np.float64)
del freqs
# always store weights for per-taper outputs
if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]:
method_kw["return_weights"] = True
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
# check validity of kwargs manually to save compute time if any are invalid
tfr_funcs = dict(
morlet=tfr_array_morlet,
Expand All @@ -1224,6 +1252,7 @@ def __init__(
self._method = method
self._inst_type = type(inst)
self._baseline = None
self._weights = None
self.preload = True # needed for __getitem__, never False for TFRs
# self._dims may also get updated by child classes
self._dims = ["channel", "freq", "time"]
Expand Down Expand Up @@ -1382,6 +1411,7 @@ def __getstate__(self):
info=self.info,
baseline=self._baseline,
decim=self._decim,
weights=self._weights,
)

def __setstate__(self, state):
Expand Down Expand Up @@ -1410,6 +1440,7 @@ def __setstate__(self, state):
self._decim = defaults["decim"]
self.preload = True
self._set_times(self._raw_times)
self._weights = state.get("weights") # objs saved before #XXX won't have
# Handle instance type. Prior to gh-11282, Raw was not a possibility so if
# `inst_type_str` is missing it must be Epochs or Evoked
unknown_class = Epochs if "epoch" in self._dims else Evoked
Expand Down Expand Up @@ -1516,6 +1547,10 @@ def _compute_tfr(self, data, n_jobs, verbose):
if self.method == "stockwell":
self._data, self._itc, freqs = result
assert np.array_equal(self._freqs, freqs)
elif self.method == "multitaper" and self._tfr_func.keywords.get(
"output", ""
) in ["complex", "phase"]:
self._data, self._weights = result
elif self._tfr_func.keywords.get("output", "").endswith("_itc"):
self._data, self._itc = result.real, result.imag
else:
Expand Down Expand Up @@ -1694,6 +1729,11 @@ def times(self):
"""The time points present in the data (in seconds)."""
return self._times_readonly

@property
def weights(self):
"""The weights used for each taper in the time-frequency estimates."""
return self._weights

@fill_doc
def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True):
"""Crop data to a given time interval in place.
Expand Down Expand Up @@ -2654,42 +2694,55 @@ def to_data_frame(
"""
# check pandas once here, instead of in each private utils function
pd = _check_pandas_installed() # noqa
# triage for Epoch-derived or unaggregated spectra
from_epo = isinstance(self, EpochsTFR)
unagg_mt = "taper" in self._dims
# arg checking
valid_index_args = ["time", "freq"]
if isinstance(self, EpochsTFR):
if from_epo:
valid_index_args.extend(["epoch", "condition"])
valid_time_formats = ["ms", "timedelta"]
index = _check_pandas_index_arguments(index, valid_index_args)
time_format = _check_time_format(time_format, valid_time_formats)
# get data
picks = _picks_to_idx(self.info, picks, "all", exclude=())
data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
axis = self._dims.index("channel")
if not isinstance(self, EpochsTFR):
ch_axis = self._dims.index("channel")
if not from_epo:
data = data[np.newaxis] # add singleton "epochs" axis
axis += 1
n_epochs, n_picks, n_freqs, n_times = data.shape
# reshape to (epochs*freqs*times) x signals
data = np.moveaxis(data, axis, -1)
data = data.reshape(n_epochs * n_freqs * n_times, n_picks)
ch_axis += 1
if not unagg_mt:
data = np.expand_dims(data, -3) # add singleton "tapers" axis
n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape
# reshape to (epochs*tapers*freqs*times) x signals
data = np.moveaxis(data, ch_axis, -1)
data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks)
# prepare extra columns / multiindex
mindex = list()
default_index = list()
times = _convert_times(times, time_format, self.info["meas_date"])
times = np.tile(times, n_epochs * n_freqs)
freqs = np.tile(np.repeat(freqs, n_times), n_epochs)
times = np.tile(times, n_epochs * n_freqs * n_tapers)
freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs)
mindex.append(("time", times))
mindex.append(("freq", freqs))
if isinstance(self, EpochsTFR):
mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs)))
if from_epo:
mindex.append(
("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers))
)
rev_event_id = {v: k for k, v in self.event_id.items()}
conditions = [rev_event_id[k] for k in self.events[:, 2]]
mindex.append(("condition", np.repeat(conditions, n_times * n_freqs)))
mindex.append(
("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
)
default_index.extend(["condition", "epoch"])
default_index.extend(["freq", "time"])
if unagg_mt:
name = "taper"
taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times)
mindex.append((name, taper_nums))
default_index.append(name)
assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
# build DataFrame
if isinstance(self, EpochsTFR):
default_index = ["condition", "epoch", "freq", "time"]
else:
default_index = ["freq", "time"]
df = _build_data_frame(
self, data, picks, long_format, mindex, index, default_index=default_index
)
Expand Down
Loading