diff --git a/mne_bids/tests/test_dig.py b/mne_bids/tests/test_dig.py index f5b7168eb..adea8e29e 100644 --- a/mne_bids/tests/test_dig.py +++ b/mne_bids/tests/test_dig.py @@ -142,26 +142,30 @@ def test_dig_template(tmp_path): for datatype in ("eeg", "ieeg"): (bids_root / "sub-01" / "ses-01" / datatype).mkdir(parents=True) + raw = _load_raw() + raw.pick(["eeg"]) + montage = raw.get_montage() + pos = montage.get_positions() + for datatype in ("eeg", "ieeg"): bids_path = _bids_path.copy().update(root=bids_root, datatype=datatype) for coord_frame in BIDS_STANDARD_TEMPLATE_COORDINATE_SYSTEMS: - raw = _load_raw() - raw.pick(["eeg"]) bids_path.update(space=coord_frame) - montage = raw.get_montage() - pos = montage.get_positions() + raw.set_montage(None) + _montage = montage.copy() mne_coord_frame = BIDS_TO_MNE_FRAMES.get(coord_frame, None) if mne_coord_frame is None: - montage.apply_trans(mne.transforms.Transform("head", "unknown")) + _montage.apply_trans(mne.transforms.Transform("head", "unknown")) else: - montage.apply_trans(mne.transforms.Transform("head", mne_coord_frame)) - _write_dig_bids(bids_path, raw, montage, acpc_aligned=True) + _montage.apply_trans(mne.transforms.Transform("head", mne_coord_frame)) + _write_dig_bids(bids_path, raw, _montage, acpc_aligned=True) electrodes_path = bids_path.copy().update( task=None, run=None, suffix="electrodes", extension=".tsv" ) coordsystem_path = bids_path.copy().update( task=None, run=None, suffix="coordsystem", extension=".json" ) + # _read_dig_bids updates raw inplace if mne_coord_frame is None: with pytest.warns( RuntimeWarning, match="not an MNE-Python coordinate frame" @@ -172,8 +176,7 @@ def test_dig_template(tmp_path): electrodes_path.update(space="fsaverage") coordsystem_path.update(space="fsaverage") _read_dig_bids(electrodes_path, coordsystem_path, datatype, raw) - montage2 = raw.get_montage() - pos2 = montage2.get_positions() + pos2 = raw.get_montage().get_positions() np.testing.assert_array_almost_equal( np.array(list(pos["ch_pos"].values())), np.array(list(pos2["ch_pos"].values())),