Commit

rework test
drammock committed Aug 22, 2024
1 parent 346e3ce commit a49d2cd
Showing 1 changed file with 60 additions and 104 deletions.
mne/stats/tests/test_cluster_level.py: 164 changes (60 additions & 104 deletions)
@@ -909,121 +909,77 @@ def test_new_cluster_api(Inst):
     """Test handling different MNE objects in the cluster API."""
     pd = pytest.importorskip("pandas")

-    n_subs, n_epo, n_chan, n_freq, n_times = 2, 2, 3, 4, 5
+    rng = np.random.default_rng(seed=8675309)
+    is_epo = Inst in (EpochsTFRArray, EpochsArray)
+    is_tfr = Inst in (EpochsTFRArray, AverageTFRArray)

+    n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5
+
+    # prepare the dimensions of the simulated data, then simulate
+    size = (n_chan,)
+    if is_epo:
+        size = (n_epo, *size)
+    if is_tfr:
+        size = (*size, n_freq)
+    size = (*size, n_times)
+    data = rng.normal(size=size)
+
     # construct the instance
     info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg")
-    # Introduce a significant difference in a specific region, time, and frequency
-    region_start = 1
-    region_end = 2
-    time_start = 2
-    time_end = 4
-    freq_start = 2
-    freq_end = 4
-
-    if Inst == EpochsArray:
-        # Create random data for EpochsArray
-        inst1 = Inst(np.random.randn(n_epo, n_chan, n_times), info=info)
-        # Adding a constant to create a difference
-        data_copy = inst1.get_data().copy()  # no data attribute for EpochsArray
-        data_copy[:, region_start:region_end, time_start:time_end] += (
-            2  # Modify the copy
-        )
-        inst2 = Inst(
-            data=data_copy, info=info
-        )  # Use the modified copy as a new instance
-
-    elif Inst == EvokedArray:
-        # Create random data for EvokedArray
-        inst1 = Inst(np.random.randn(n_chan, n_times), info=info)
-        data_copy = inst1.data.copy()
-        data_copy[region_start:region_end, time_start:time_end] += 2
-        inst2 = Inst(data=data_copy, info=info)
-
-    elif Inst == EpochsTFRArray:
-        # Create random data for EpochsTFRArray
-        data_tfr1 = np.random.randn(n_epo, n_chan, n_freq, n_times)
-        data_tfr2 = np.random.randn(n_epo, n_chan, n_freq, n_times)
-        inst1 = Inst(
-            data=data_tfr1, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
-        )
-        inst2 = Inst(
-            data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
-        )
-        data_tfr2 = inst2.data.copy()
-        data_tfr2[
-            :, region_start:region_end, freq_start:freq_end, time_start:time_end
-        ] += 2
-        inst2 = Inst(
-            data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
-        )
-
-    elif Inst == AverageTFRArray:
-        # Create random data for AverageTFRArray
-        data_tfr1 = np.random.randn(n_chan, n_freq, n_times)
-        data_tfr2 = np.random.randn(n_chan, n_freq, n_times)
-        inst1 = Inst(
-            data=data_tfr1, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
-        )
-        inst2 = Inst(
-            data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
-        )
-        data_tfr2 = inst2.data.copy()
-        data_tfr2[
-            region_start:region_end, freq_start:freq_end, time_start:time_end
-        ] += 2
-        inst2 = Inst(
-            data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
-        )
-
-    if Inst == EvokedArray or Inst == AverageTFRArray:
-        # Generate random noise
-        noise = np.random.normal(loc=0, scale=0.1, size=inst1.data.shape)
-        # add noise to the data of the second subject
-        inst1_n = inst1.copy()
-        inst1_n.data = inst1.data + noise
-        inst2_n = inst2.copy()
-        inst2_n.data = inst2.data + noise
-        data = [inst1, inst2, inst1_n, inst2_n]
-        conds = ["a", "b"] * n_subs
+    kw = dict(times=np.arange(n_times), freqs=np.arange(n_freq)) if is_tfr else dict()
+    cond_a = Inst(data=data, info=info, **kw)
+    cond_b = cond_a.copy()
+    # introduce a significant difference in a specific region, time, and frequency
+    ch_start, ch_end = 0, 2  # 2 channels
+    t_start, t_end = 2, 4  # 2 times
+    f_start, f_end = 2, 4  # 2 freqs
+    if is_tfr:
+        cond_b._data[..., ch_start:ch_end, f_start:f_end, t_start:t_end] += 2
+    else:
+        cond_b._data[..., ch_start:ch_end, t_start:t_end] += 2
+    # for Evokeds/AverageTFRs, we create fake "subjects" as our observations within each
+    # condition. We add a bit of noise while we do so.
+    if not is_epo:
+        insts = list()
+        for cond in cond_a, cond_b:
+            for _n in range(n_epo):
+                if not _n:
+                    insts.append(cond)
+                    continue
+                _cond = cond.copy()
+                _cond.data += rng.normal(scale=0.1, size=_cond.data.shape)
+                insts.append(_cond)
+        conds = np.repeat(["a", "b"], n_epo).tolist()
     else:
-        data = [inst1, inst2]
+        # For Epochs(TFR)Array, each epoch is an observation and they're already
+        # noisy/non-identical, so no duplication / noise-addition necessary.
+        insts = [cond_a, cond_b]
         conds = ["a", "b"]

-    df = pd.DataFrame(dict(data=data, condition=conds))
-
     # run new clustering API
+    df = pd.DataFrame(dict(data=insts, condition=conds))
     kwargs = dict(
         n_permutations=100, seed=42, tail=1, buffer_size=None, out_type="mask"
     )
-
     result_new_api = cluster_test(df, "data~condition", **kwargs)

     # make sure channels are last dimension for old API
-    if Inst == EpochsArray:
-        inst1 = inst1.get_data().transpose(0, 2, 1)
-        inst2 = inst2.get_data().transpose(0, 2, 1)
-    elif Inst == EpochsTFRArray:
-        inst1 = inst1.data.transpose(0, 3, 2, 1)
-        inst2 = inst2.data.transpose(0, 3, 2, 1)
-    elif Inst == AverageTFRArray:
-        inst1 = inst1.data.transpose(2, 1, 0)
-        inst2 = inst2.data.transpose(2, 1, 0)
-        inst1_n = inst1_n.data.transpose(2, 1, 0)
-        inst2_n = inst2_n.data.transpose(2, 1, 0)
-        # combine the data of the two subjects
-        inst1 = np.concatenate([inst1[np.newaxis, :], inst1_n[np.newaxis, :]], axis=0)
-        inst2 = np.concatenate([inst2[np.newaxis, :], inst2_n[np.newaxis, :]], axis=0)
+    if is_epo:
+        axes = (0, 3, 2, 1) if is_tfr else (0, 2, 1)
+        X = [cond_a.get_data().transpose(*axes), cond_b.get_data().transpose(*axes)]
     else:
-        inst1 = inst1.data.transpose(1, 0)
-        inst2 = inst2.data.transpose(1, 0)
-        inst1_n = inst1_n.data.transpose(1, 0)
-        inst2_n = inst2_n.data.transpose(1, 0)
-        # combine the data of the two subjects
-        inst1 = np.concatenate([inst1[np.newaxis, :], inst1_n[np.newaxis, :]], axis=0)
-        inst2 = np.concatenate([inst2[np.newaxis, :], inst2_n[np.newaxis, :]], axis=0)
-
-    F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test(
-        [inst1, inst2], **kwargs
-    )
+        axes = (2, 1, 0) if is_tfr else (1, 0)
+        Xa = list()
+        Xb = list()
+        for inst, cond in zip(insts, conds):
+            container = Xa if cond == "a" else Xb
+            container.append(inst.get_data().transpose(*axes))
+        X = [np.stack(Xa), np.stack(Xb)]

+    F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test(X, **kwargs)
     assert_array_almost_equal(result_new_api.H0, H0)
     assert_array_almost_equal(result_new_api.stat_obs, F_obs)
     assert_array_almost_equal(result_new_api.cluster_p_values, cluster_pvals)
-    assert result_new_api.clusters == clusters
+    assert len(result_new_api.clusters) == len(clusters)
+    for clu1, clu2 in zip(result_new_api.clusters, clusters):
+        assert_array_equal(clu1, clu2)

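For orientation, the equivalence exercised by the reworked test can be sketched outside of pytest. The snippet below is a minimal illustration only, not part of the commit: `cluster_test`, its "data~condition" formula interface, and the result attributes `H0`/`stat_obs` are assumed to behave as exercised in the test above (they are not part of a released MNE-Python API), and the `from mne.stats import cluster_test` import path is likewise an assumption; `permutation_cluster_test` is the existing stable API.

import numpy as np
import pandas as pd
from mne import EvokedArray, create_info
from mne.stats import permutation_cluster_test
from mne.stats import cluster_test  # assumption: in-development API from this branch

rng = np.random.default_rng(seed=8675309)
n_subs, n_chan, n_times = 6, 3, 5
info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg")

# two conditions with n_subs "subject" Evokeds each; condition "b" gets an effect
insts, conds = [], []
for cond in ("a", "b"):
    for _ in range(n_subs):
        data = rng.normal(size=(n_chan, n_times))
        if cond == "b":
            data[0:2, 2:4] += 2  # effect in 2 channels x 2 time samples
        insts.append(EvokedArray(data, info))
        conds.append(cond)

kwargs = dict(n_permutations=100, seed=42, tail=1, buffer_size=None, out_type="mask")

# new API: a DataFrame of instances plus a formula naming the data/condition columns
df = pd.DataFrame(dict(data=insts, condition=conds))
result = cluster_test(df, "data~condition", **kwargs)

# stable API: one stacked array per condition, with channels as the last dimension
X = [
    np.stack([inst.data.T for inst, c in zip(insts, conds) if c == cond])
    for cond in ("a", "b")
]
F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test(X, **kwargs)

# the test above asserts that both code paths agree
np.testing.assert_array_almost_equal(result.H0, H0)
np.testing.assert_array_almost_equal(result.stat_obs, F_obs)

The same pattern generalizes to the Epochs and TFR variants in the test; only the data shapes and the transpose axes change.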