-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Add option to store and return TFR taper weights #12910
Open
tsbinns
wants to merge
39
commits into
mne-tools:main
Choose a base branch
from
tsbinns:add_tfr_weights
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
9fe1fb6
Add option to store and return tfr taper weights
tsbinns 45c6a0b
Merge remote-tracking branch 'upstream/main' into add_tfr_weights
tsbinns 82fc2f7
Update docstrings
tsbinns 9f30a59
Merge branch 'main' into add_tfr_weights
tsbinns a49f934
Remove whitespace
tsbinns 48afced
Merge branch 'add_tfr_weights' of https://github.com/tsbinns/mne-pyth…
tsbinns 7c3dcfa
Add PR num
tsbinns 8c16716
Revert "Update docstrings"
tsbinns 51b8cd0
Remove outdated default setting
tsbinns 2f9a4b4
Reapply "Update docstrings"
tsbinns b4537b2
Update docstrings
tsbinns f155238
Merge branch 'main' into add_tfr_weights
tsbinns 2a03e9b
Merge branch 'main' into add_tfr_weights
tsbinns 045d9a2
Merge branch 'main' into add_tfr_weights
tsbinns 8d645bb
Enforce return_weights as named param
tsbinns 5ad9bd5
Merge branch 'main' into add_tfr_weights
tsbinns 1c02b40
Add missing test coverage
tsbinns 54f2a32
Add changelog entry
tsbinns 6a23556
Merge branch 'fix_tfr_tapers' into fix_tfr_multitapers
tsbinns a107991
Begin add support for tapers in array objs
tsbinns 01c486c
Begin add support for tapers in array objs
tsbinns ca27179
Fix docstring entries
tsbinns b14a100
Fix faulty state check
tsbinns 972aba2
Add weights to AverageTFR
tsbinns e11fa2b
Expand test coverage
tsbinns aaef4b7
Merge branch 'main' into add_tfr_weights
tsbinns 999d122
Disallow aggregating tapers in combine_tfr
tsbinns e12b09a
Updated docstrings
tsbinns dd61955
Merge branch 'main' into add_tfr_weights
tsbinns 728701e
Add placeholder versionadded tags
tsbinns 6af3310
Merge branch 'add_tfr_weights' of https://github.com/tsbinns/mne-pyth…
tsbinns e3a3c4b
Merge remote-tracking branch 'upstream/main' into add_tfr_weights
tsbinns de39d25
Begin fixing to_data_frame
tsbinns 80126a7
Fix to_data_frame bug with tapers
tsbinns 82dfab9
Fix plotting with tapers
tsbinns 5b150aa
Merge branch 'main' into add_tfr_weights
tsbinns 0d3d85d
Merge branch 'main' into add_tfr_weights
tsbinns 012bd94
Add version tag
tsbinns e5eedee
Merge branch 'main' into add_tfr_weights
tsbinns File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Added the option to return taper weights from | ||
:func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the | ||
:class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -432,17 +432,21 @@ def test_tfr_morlet(): | |
def test_dpsswavelet(): | ||
"""Test DPSS tapers.""" | ||
freqs = np.arange(5, 25, 3) | ||
Ws = _make_dpss( | ||
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True | ||
Ws, weights = _make_dpss( | ||
1000, | ||
freqs=freqs, | ||
n_cycles=freqs / 2.0, | ||
time_bandwidth=4.0, | ||
zero_mean=True, | ||
return_weights=True, | ||
) | ||
|
||
assert len(Ws) == 3 # 3 tapers expected | ||
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected | ||
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) | ||
|
||
# Check that zero mean is true | ||
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 | ||
|
||
assert len(Ws[0]) == len(freqs) # As many wavelets as asked for | ||
|
||
|
||
@pytest.mark.slowtest | ||
def test_tfr_multitaper(): | ||
|
@@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): | |
with tfr.info._unlock(): | ||
tfr.info["meas_date"] = want | ||
assert tfr_loaded == tfr | ||
# test with taper dimension and weights | ||
n_tapers = 3 # anything >= 1 should do | ||
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs | ||
state = tfr.__getstate__() | ||
state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim | ||
state["weights"] = weights # add weights | ||
state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims | ||
tfr = EpochsTFR(inst=state) | ||
tfr.save(fname, overwrite=True) | ||
tfr_loaded = read_tfrs(fname) | ||
assert tfr_loaded == tfr | ||
# test overwrite | ||
with pytest.raises(OSError, match="Destination file exists."): | ||
tfr.save(fname, overwrite=False) | ||
|
@@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked): | |
AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) | ||
|
||
|
||
def test_epochstfr_init_errors(epochs_tfr): | ||
"""Test __init__ for EpochsTFR.""" | ||
state = epochs_tfr.__getstate__() | ||
with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): | ||
EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) | ||
@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) | ||
def test_tfr_init_errors(inst, request, average_tfr): | ||
"""Test __init__ for {Raw,Epochs,Average}TFR.""" | ||
# Load data | ||
inst = _get_inst(inst, request, average_tfr=average_tfr) | ||
state = inst.__getstate__() | ||
# Prepare for TFRArray object instantiation | ||
inst_name = inst.__class__.__name__ | ||
class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) | ||
ndims_mapping = dict( | ||
RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D") | ||
) | ||
TFR = class_mapping[inst_name] | ||
allowed_ndims = ndims_mapping[inst_name] | ||
# Check errors caught | ||
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): | ||
TFR(inst=state | dict(data=inst.data[..., 0])) | ||
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): | ||
TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) | ||
with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): | ||
EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) | ||
TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) | ||
with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): | ||
EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1])) | ||
TFR(inst=state | dict(times=inst.times[:-1])) | ||
with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): | ||
EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) | ||
TFR(inst=state | dict(freqs=inst.freqs[:-1])) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
@@ -830,6 +859,25 @@ def test_plot(): | |
plt.close("all") | ||
|
||
|
||
@pytest.mark.parametrize("output", ("complex", "phase")) | ||
def test_plot_multitaper_complex_phase(output): | ||
"""Test TFR plotting of data with a taper dimension.""" | ||
# Create example data with a taper dimension | ||
n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) | ||
data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) | ||
if output == "complex": | ||
data = data + np.random.rand(*data.shape) * 1j # add imaginary data | ||
times = np.arange(n_times) | ||
freqs = np.arange(n_freqs) | ||
weights = np.random.rand(n_tapers, n_freqs) | ||
info = mne.create_info(n_chans, 1000.0, "eeg") | ||
tfr = AverageTFRArray( | ||
info=info, data=data, times=times, freqs=freqs, weights=weights | ||
) | ||
# Check that plotting works | ||
tfr.plot() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"timefreqs,title,combine", | ||
( | ||
|
@@ -1154,6 +1202,15 @@ def test_averaging_epochsTFR(): | |
): | ||
power.average(method=np.mean) | ||
|
||
# Check it doesn't run for taper spectra | ||
tapered = epochs.compute_tfr( | ||
method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" | ||
) | ||
with pytest.raises( | ||
NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." | ||
): | ||
tapered.average() | ||
|
||
|
||
def test_averaging_freqsandtimes_epochsTFR(): | ||
"""Test that EpochsTFR averaging freqs methods work.""" | ||
|
@@ -1258,12 +1315,15 @@ def test_to_data_frame(): | |
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] | ||
n_picks = len(ch_names) | ||
ch_types = ["eeg"] * n_picks | ||
n_tapers = 2 | ||
n_freqs = 5 | ||
n_times = 6 | ||
data = np.random.rand(n_epos, n_picks, n_freqs, n_times) | ||
times = np.arange(6) | ||
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) | ||
times = np.arange(n_times) | ||
srate = 1000.0 | ||
freqs = np.arange(5) | ||
freqs = np.arange(n_freqs) | ||
tapers = np.arange(n_tapers) | ||
weights = np.ones((n_tapers, n_freqs)) | ||
events = np.zeros((n_epos, 3), dtype=int) | ||
events[:, 0] = np.arange(n_epos) | ||
events[:, 2] = np.arange(5, 5 + n_epos) | ||
|
@@ -1276,6 +1336,7 @@ def test_to_data_frame(): | |
freqs=freqs, | ||
events=events, | ||
event_id=event_id, | ||
weights=weights, | ||
) | ||
# test index checking | ||
with pytest.raises(ValueError, match="options. Valid index options are"): | ||
|
@@ -1287,32 +1348,51 @@ def test_to_data_frame(): | |
# test wide format | ||
df_wide = tfr.to_data_frame() | ||
assert all(np.isin(tfr.ch_names, df_wide.columns)) | ||
assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) | ||
assert all( | ||
np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) | ||
) | ||
# test long format | ||
df_long = tfr.to_data_frame(long_format=True) | ||
expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") | ||
expected = ( | ||
"condition", | ||
"epoch", | ||
"freq", | ||
"time", | ||
"channel", | ||
"ch_type", | ||
"value", | ||
"taper", | ||
) | ||
assert set(expected) == set(df_long.columns) | ||
assert set(tfr.ch_names) == set(df_long["channel"]) | ||
assert len(df_long) == tfr.data.size | ||
# test long format w/ index | ||
df_long = tfr.to_data_frame(long_format=True, index=["freq"]) | ||
del df_wide, df_long | ||
# test whether data is in correct shape | ||
df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) | ||
df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) | ||
data = tfr.data | ||
assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) | ||
# compare arbitrary observation: | ||
assert ( | ||
df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] | ||
== data[1, 3, 1, 2] | ||
df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] | ||
== data[1, 3, 1, 1, 2] | ||
) | ||
|
||
# Check also for AverageTFR: | ||
# (remove taper dimension before averaging) | ||
state = tfr.__getstate__() | ||
state["data"] = state["data"][:, :, 0] | ||
state["dims"] = ("epoch", "channel", "freq", "time") | ||
state["weights"] = None | ||
tfr = EpochsTFR(inst=state) | ||
tfr = tfr.average() | ||
with pytest.raises(ValueError, match="options. Valid index options are"): | ||
tfr.to_data_frame(index=["epoch", "condition"]) | ||
with pytest.raises(ValueError, match='"epoch" is not a valid option'): | ||
tfr.to_data_frame(index="epoch") | ||
with pytest.raises(ValueError, match='"taper" is not a valid option'): | ||
tfr.to_data_frame(index="taper") | ||
with pytest.raises(TypeError, match="index must be `None` or a string "): | ||
tfr.to_data_frame(index=np.arange(400)) | ||
# test wide format | ||
|
@@ -1348,11 +1428,13 @@ def test_to_data_frame_index(index): | |
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] | ||
n_picks = len(ch_names) | ||
ch_types = ["eeg"] * n_picks | ||
n_tapers = 2 | ||
n_freqs = 5 | ||
n_times = 6 | ||
data = np.random.rand(n_epos, n_picks, n_freqs, n_times) | ||
times = np.arange(6) | ||
freqs = np.arange(5) | ||
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) | ||
times = np.arange(n_times) | ||
freqs = np.arange(n_freqs) | ||
weights = np.ones((n_tapers, n_freqs)) | ||
events = np.zeros((n_epos, 3), dtype=int) | ||
events[:, 0] = np.arange(n_epos) | ||
events[:, 2] = np.arange(5, 8) | ||
|
@@ -1365,14 +1447,15 @@ def test_to_data_frame_index(index): | |
freqs=freqs, | ||
events=events, | ||
event_id=event_id, | ||
weights=weights, | ||
) | ||
df = tfr.to_data_frame(picks=[0, 2, 3], index=index) | ||
# test index order/hierarchy preservation | ||
if not isinstance(index, list): | ||
index = [index] | ||
assert list(df.index.names) == index | ||
# test that non-indexed data were present as columns | ||
non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) | ||
non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) | ||
if len(non_index): | ||
assert all(np.isin(non_index, df.columns)) | ||
|
||
|
@@ -1538,7 +1621,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): | |
def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): | ||
"""Test Epochs.compute_tfr(output="complex"/"phase").""" | ||
tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) | ||
assert len(tfr.shape) == 5 | ||
assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time | ||
assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match | ||
|
||
|
||
@pytest.mark.parametrize("copy", (False, True)) | ||
|
@@ -1550,6 +1634,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): | |
assert avgs[0].comment == str(epochs_tfr.events[0, -1]) | ||
|
||
|
||
@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked")) | ||
def test_tfrarray_tapered_spectra(obj_type): | ||
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" | ||
# Create example data with a taper dimension | ||
n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6) | ||
data_shape = (n_chans, n_tapers, n_freqs, n_times) | ||
if obj_type == "epochs": | ||
data_shape = (n_epochs,) + data_shape | ||
data = np.random.rand(*data_shape) | ||
times = np.arange(n_times) | ||
freqs = np.arange(n_freqs) | ||
weights = np.random.rand(n_tapers, n_freqs) | ||
info = mne.create_info(n_chans, 1000.0, "eeg") | ||
# Prepare for TFRArray object instantiation | ||
defaults = dict(info=info, data=data, times=times, freqs=freqs) | ||
class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray) | ||
TFRArray = class_mapping[obj_type] | ||
# Check TFRArray instantiation runs with good data | ||
TFRArray(**defaults, weights=weights) | ||
# Check taper dimension but no weights caught | ||
with pytest.raises( | ||
ValueError, match="Taper dimension in data, but no weights found." | ||
): | ||
TFRArray(**defaults) | ||
# Check mismatching n_taper in weights caught | ||
with pytest.raises( | ||
ValueError, match=r"Taper axis .* doesn't match weights attribute" | ||
): | ||
TFRArray(**defaults, weights=weights[:-1]) | ||
# Check mismatching n_freq in weights caught | ||
with pytest.raises( | ||
ValueError, match=r"Frequency axis .* doesn't match weights attribute" | ||
): | ||
TFRArray(**defaults, weights=weights[:, :-1]) | ||
|
||
|
||
def test_tfr_proj(epochs): | ||
"""Test `compute_tfr(proj=True)`.""" | ||
epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) | ||
|
@@ -1731,3 +1851,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): | |
assert re.match( | ||
rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("output", ("complex", "phase")) | ||
def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): | ||
"""Test plot_joint/topo/topomap() for data with a taper dimension.""" | ||
# Compute TFR with taper dimension | ||
tfr = evoked.compute_tfr( | ||
method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output | ||
) | ||
# Check that plotting works | ||
tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed | ||
tfr.plot_topo() | ||
tfr.plot_topomap() | ||
Comment on lines
+1856
to
+1866
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Basic test that just checks whether the code runs, but it covers the lines where changes to topo-related plotting were made, and other tests deal with non-default method params. |
||
|
||
|
||
def test_combine_tfr_error_catch(average_tfr): | ||
"""Test combine_tfr() catches errors.""" | ||
# check unrecognised weights string caught | ||
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): | ||
combine_tfr([average_tfr, average_tfr], weights="foo") | ||
# check bad weights size caught | ||
with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): | ||
combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) | ||
# check different channel names caught | ||
state = average_tfr.__getstate__() | ||
new_info = average_tfr.info.copy() | ||
average_tfr_bad = AverageTFR( | ||
inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) | ||
) | ||
with pytest.raises(AssertionError, match=".* do not contain the same channels"): | ||
combine_tfr([average_tfr, average_tfr_bad]) | ||
# check different times caught | ||
average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) | ||
with pytest.raises( | ||
AssertionError, match=".* do not contain the same time instants" | ||
): | ||
combine_tfr([average_tfr, average_tfr_bad]) | ||
# check taper dim caught | ||
n_tapers = 3 # anything >= 1 should do | ||
weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs | ||
state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) | ||
state["weights"] = weights | ||
state["dims"] = ("channel", "taper", "freq", "time") | ||
average_tfr_taper = AverageTFR(inst=state) | ||
with pytest.raises( | ||
NotImplementedError, | ||
match="Aggregating multitaper tapers across TFR datasets is not supported.", | ||
): | ||
combine_tfr([average_tfr_taper, average_tfr_taper]) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, a pretty basic test that just checks whether plotting code runs, but covers the changes and non-default params tested elswehere.