Skip to content

Commit

Permalink
Fix plotting with tapers
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Dec 14, 2024
1 parent 80126a7 commit 82dfab9
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 88 deletions.
64 changes: 48 additions & 16 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,25 @@ def test_plot():
plt.close("all")


@pytest.mark.parametrize("output", ("complex", "phase"))
def test_plot_multitaper_complex_phase(output):
"""Test TFR plotting of data with a taper dimension."""
# Create example data with a taper dimension
n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3)
data = np.random.rand(n_chans, n_tapers, n_freqs, n_times)
if output == "complex":
data = data + np.random.rand(*data.shape) * 1j # add imaginary data
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.random.rand(n_tapers, n_freqs)
info = mne.create_info(n_chans, 1000.0, "eeg")
tfr = AverageTFRArray(
info=info, data=data, times=times, freqs=freqs, weights=weights
)
# Check that plotting works
tfr.plot()


@pytest.mark.parametrize(
"timefreqs,title,combine",
(
Expand Down Expand Up @@ -1611,23 +1630,23 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
assert avgs[0].comment == str(epochs_tfr.events[0, -1])


@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked"))
def test_tfrarray_tapered_spectra(inst, evoked, request):
@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked"))
def test_tfrarray_tapered_spectra(obj_type):
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
# Load data object
inst = _get_inst(inst, request, evoked=evoked)
inst.pick("mag")
# Compute TFR with taper dimension (can be complex or phase output)
tfr = inst.compute_tfr(
method="multitaper", freqs=freqs_linspace, n_cycles=4, output="complex"
)
tfr_array, weights = tfr.get_data(), tfr.weights
# Create example data with a taper dimension
n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6)
data_shape = (n_chans, n_tapers, n_freqs, n_times)
if obj_type == "epochs":
data_shape = (n_epochs,) + data_shape
data = np.random.rand(*data_shape)
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.random.rand(n_tapers, n_freqs)
info = mne.create_info(n_chans, 1000.0, "eeg")
# Prepare for TFRArray object instantiation
defaults = dict(
info=inst.info, data=tfr_array, times=inst.times, freqs=freqs_linspace
)
class_mapping = dict(Raw=RawTFRArray, Epochs=EpochsTFRArray, Evoked=AverageTFRArray)
TFRArray = class_mapping[inst.__class__.__name__]
defaults = dict(info=info, data=data, times=times, freqs=freqs)
class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray)
TFRArray = class_mapping[obj_type]
# Check TFRArray instantiation runs with good data
TFRArray(**defaults, weights=weights)
# Check taper dimension but no weights caught
Expand Down Expand Up @@ -1830,7 +1849,20 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
)


def test_combine_tfr_error_catch(request, average_tfr):
@pytest.mark.parametrize("output", ("complex", "phase"))
def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked):
"""Test plot_joint/topo/topomap() for data with a taper dimension."""
# Compute TFR with taper dimension
tfr = evoked.compute_tfr(
method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output
)
# Check that plotting works
tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed
tfr.plot_topo()
tfr.plot_topomap()


def test_combine_tfr_error_catch(average_tfr):
"""Test combine_tfr() catches errors."""
# check unrecognised weights string caught
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'):
Expand Down
121 changes: 51 additions & 70 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,7 @@ def _onselect(
fmax=fmax,
baseline=baseline,
mode=mode,
taper_weights=self.weights,
verbose=verbose,
)
# average over times and freqs
Expand Down Expand Up @@ -2026,6 +2027,7 @@ def plot(
baseline=baseline,
mode=mode,
dB=dB,
taper_weights=self.weights,
verbose=verbose,
)
# shape
Expand All @@ -2036,6 +2038,9 @@ def plot(
want_shape[ch_axis] = len(idx_picks) if combine is None else 1
want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping
want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping
want_shape = [
n for i, n in enumerate(want_shape) if self._dims[i] != "taper"
] # tapers must be aggregated over by now
want_shape = tuple(want_shape)
# combine
combine_was_none = combine is None
Expand Down Expand Up @@ -2379,6 +2384,7 @@ def plot_joint(
fmax=_fmax,
baseline=baseline,
mode=mode,
taper_weights=self.weights,
verbose=verbose,
)
_data = _data.mean(axis=(-1, -2)) # avg over times and freqs
Expand Down Expand Up @@ -2527,23 +2533,23 @@ def plot_topo(
info, data = _prepare_picks(info, data, picks, axis=0)
del picks

# TODO this is the only remaining call to _preproc_tfr; should be refactored
# (to use _prep_data_for_plot?)
data, times, freqs, vmin, vmax = _preproc_tfr(
# baseline, crop, convert complex to power, aggregate tapers, and dB scaling
data, times, freqs = _prep_data_for_plot(
data,
times,
freqs,
tmin,
tmax,
fmin,
fmax,
mode,
baseline,
vmin,
vmax,
dB,
info["sfreq"],
tmin=tmin,
tmax=tmax,
fmin=fmin,
fmax=fmax,
baseline=baseline,
mode=mode,
dB=dB,
taper_weights=self.weights,
verbose=verbose,
)
# get vlims
vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)

if layout is None:
from mne import find_layout
Expand Down Expand Up @@ -4054,62 +4060,6 @@ def _centered(arr, newsize):
return arr[tuple(myslice)]


def _preproc_tfr(
data,
times,
freqs,
tmin,
tmax,
fmin,
fmax,
mode,
baseline,
vmin,
vmax,
dB,
sfreq,
copy=None,
):
"""Aux Function to prepare tfr computation."""
if copy is None:
copy = baseline is not None
data = rescale(data, times, baseline, mode, copy=copy)

if np.iscomplexobj(data):
# complex amplitude → real power (for plotting); if data are
# real-valued they should already be power
data = (data * data.conj()).real

# crop time
itmin, itmax = None, None
idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0]
if tmin is not None:
itmin = idx[0]
if tmax is not None:
itmax = idx[-1] + 1

times = times[itmin:itmax]

# crop freqs
ifmin, ifmax = None, None
idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0]
if fmin is not None:
ifmin = idx[0]
if fmax is not None:
ifmax = idx[-1] + 1

freqs = freqs[ifmin:ifmax]

# crop data
data = data[:, ifmin:ifmax, itmin:itmax]

if dB:
data = 10 * np.log10(data)

vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
return data, times, freqs, vmin, vmax


def _ensure_slice(decim):
"""Aux function checking the decim parameter."""
_validate_type(decim, ("int-like", slice), "decim")
Expand Down Expand Up @@ -4344,6 +4294,7 @@ def _prep_data_for_plot(
baseline=None,
mode=None,
dB=False,
taper_weights=None,
verbose=None,
):
# baseline
Expand All @@ -4357,9 +4308,39 @@ def _prep_data_for_plot(
freqs = freqs[freq_mask]
# crop data
data = data[..., freq_mask, :][..., time_mask]
# complex amplitude → real power; real-valued data is already power (or ITC)
# handle unaggregated multitaper (complex or phase multitaper data)
if taper_weights is not None: # assumes a taper dimension
logger.info("Aggregating multitaper estimates before plotting...")
if np.iscomplexobj(data): # complex coefficients → power
data = _tfr_from_mt(data, taper_weights)
else: # tapered phase data → weighted phase data
data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1)
# handle remaining complex amplitude → real power
if np.iscomplexobj(data):
data = (data * data.conj()).real
if dB:
data = 10 * np.log10(data)
return data, times, freqs


def _tfr_from_mt(x_mt, weights):
"""Aggregate complex multitaper coefficients over tapers and convert to power.
Parameters
----------
x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times)
The complex-valued multitaper coefficients.
weights : array, shape (n_tapers, n_freqs)
The weights to use to combine the tapered estimates.
Returns
-------
tfr : array, shape (n_channels, n_freqs, n_times)
The time-frequency power estimates.
"""
weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims
tfr = weights * x_mt
tfr *= tfr.conj()
tfr = tfr.real.sum(axis=1)
tfr *= 2 / (weights * weights.conj()).real.sum(axis=1)
return tfr
25 changes: 24 additions & 1 deletion mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
compute_bridged_electrodes,
compute_current_source_density,
)
from mne.time_frequency.tfr import AverageTFRArray
from mne.time_frequency.tfr import AverageTFR, AverageTFRArray
from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap
from mne.viz.tests.test_raw import _proj_status
from mne.viz.topomap import (
Expand Down Expand Up @@ -610,6 +610,29 @@ def test_plot_tfr_topomap():
ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
)

# test data with taper dimension (real)
data = np.expand_dims(data, axis=1)
weights = np.random.rand(1, n_freqs)
tfr = AverageTFRArray(
info=info,
data=data,
times=times,
freqs=np.arange(n_freqs),
nave=nave,
weights=weights,
)
tfr.plot_topomap(
ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
)
# test data with taper dimension (complex)
state = tfr.__getstate__()
tfr = AverageTFR(inst=state | dict(data=data * (1 + 1j)))
tfr.plot_topomap(
ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
)
# remove taper dim before proceeding
data = data[:, 0]

# test real numbers
tfr = AverageTFRArray(
info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave
Expand Down
14 changes: 13 additions & 1 deletion mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1882,14 +1882,26 @@ def plot_tfr_topomap(
tfr, ch_type, sphere=sphere
)
outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
data = tfr.data[picks, :, :]
data = tfr.data[picks]

# merging grads before rescaling makes ERDs visible
if merge_channels:
data, names = _merge_ch_data(data, ch_type, names, method="mean")

data = rescale(data, tfr.times, baseline, mode, copy=True)

# handle unaggregated multitaper (complex or phase multitaper data)
if tfr.weights is not None: # assumes a taper dimension
logger.info("Aggregating multitaper estimates before plotting...")
weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims
data = weights * data
if np.iscomplexobj(data): # complex coefficients → power
data *= data.conj()
data = data.real.sum(axis=1)
data *= 2 / (weights * weights.conj()).real.sum(axis=1)
else: # tapered phase data → weighted phase data
data = data.mean(axis=1)
# handle remaining complex amplitude → real power
if np.iscomplexobj(data):
data = np.sqrt((data * data.conj()).real)

Expand Down

0 comments on commit 82dfab9

Please sign in to comment.