From 82fc2f7fe450dee8445cb9b48993944336e2aedc Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:24:33 +0200 Subject: [PATCH] Update docstrings --- mne/time_frequency/multitaper.py | 5 ++-- mne/time_frequency/tfr.py | 50 +++++++++++++++++++------------- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index f5f6f79a0b3..fc926af4863 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -505,7 +505,8 @@ def tfr_array_multitaper( coherence across trials. return_weights : bool, default False - If True, return the taper weights. Only applies if ``output="complex"``. + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. .. versionadded:: 1.9.0 @@ -528,7 +529,7 @@ def tfr_array_multitaper( contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. weights : array of shape (n_tapers, n_freqs) - The taper weights. Only returned if ``output="complex"`` and + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and ``return_weights=True``. See Also diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index e9d028e7e50..dd64c18d9e3 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1215,9 +1215,6 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) - # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) - if method == "morlet": - method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -2697,9 +2694,12 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) @@ -2707,32 +2707,42 @@ def to_data_frame( # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + default_index.extend(["freq", "time"]) + if unagg_mt: + name = "taper" + taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times) + mindex.append((name, taper_nums)) + default_index.append(name) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index )