From 7071a0e24e121b28c851565f2e64a0128941e83a Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 4 Dec 2024 16:53:02 -0500 Subject: [PATCH] ENH: Add round-trip channel name saving (#13003) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Stefan Appelhoff --- doc/changes/devel/13003.newfeature.rst | 1 + mne/_fiff/_digitization.py | 29 +++++++++++++--- mne/_fiff/write.py | 6 +++- mne/channels/montage.py | 48 +++++++++++++++++++------- mne/channels/tests/test_montage.py | 47 +++++++++++++++++-------- mne/viz/tests/test_montage.py | 2 +- tutorials/forward/35_eeg_no_mri.py | 2 +- 7 files changed, 101 insertions(+), 34 deletions(-) create mode 100644 doc/changes/devel/13003.newfeature.rst diff --git a/doc/changes/devel/13003.newfeature.rst b/doc/changes/devel/13003.newfeature.rst new file mode 100644 index 00000000000..141265406a8 --- /dev/null +++ b/doc/changes/devel/13003.newfeature.rst @@ -0,0 +1 @@ +Added support for saving and loading channel names from FIF in :meth:`mne.channels.DigMontage.save` and :meth:`mne.channels.read_dig_fif` and added the convention that they should be saved as ``-dig.fif``, by `Eric Larson`_. diff --git a/mne/_fiff/_digitization.py b/mne/_fiff/_digitization.py index 6b14701d0b8..e55fd5d2dae 100644 --- a/mne/_fiff/_digitization.py +++ b/mne/_fiff/_digitization.py @@ -11,7 +11,7 @@ from .constants import FIFF, _coord_frame_named from .tag import read_tag from .tree import dir_tree_find -from .write import start_and_end_file, write_dig_points +from .write import _safe_name_list, start_and_end_file, write_dig_points _dig_kind_dict = { "cardinal": FIFF.FIFFV_POINT_CARDINAL, @@ -162,10 +162,11 @@ def __eq__(self, other): # noqa: D105 return np.allclose(self["r"], other["r"]) -def _read_dig_fif(fid, meas_info): +def _read_dig_fif(fid, meas_info, *, return_ch_names=False): """Read digitizer data from a FIFF file.""" isotrak = dir_tree_find(meas_info, FIFF.FIFFB_ISOTRAK) dig = None + ch_names = None if len(isotrak) == 0: logger.info("Isotrak not found") elif len(isotrak) > 1: @@ -183,13 +184,21 @@ def _read_dig_fif(fid, meas_info): elif kind == FIFF.FIFF_MNE_COORD_FRAME: tag = read_tag(fid, pos) coord_frame = _coord_frame_named.get(int(tag.data.item())) + elif kind == FIFF.FIFF_MNE_CH_NAME_LIST: + tag = read_tag(fid, pos) + ch_names = _safe_name_list(tag.data, "read", "ch_names") for d in dig: d["coord_frame"] = coord_frame - return _format_dig_points(dig) + out = _format_dig_points(dig) + if return_ch_names: + out = (out, ch_names) + return out @verbose -def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None): +def write_dig( + fname, pts, coord_frame=None, *, ch_names=None, overwrite=False, verbose=None +): """Write digitization data to a FIF file. Parameters @@ -203,6 +212,10 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None): If all the points have the same coordinate frame, specify the type here. Can be None (default) if the points could have varying coordinate frames. + ch_names : list of str | None + Channel names associated with the digitization points, if available. + + .. versionadded:: 1.9 %(overwrite)s .. versionadded:: 1.0 @@ -222,9 +235,15 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None): "Points have coord_frame entries that are incompatible with " f"coord_frame={coord_frame}: {tuple(bad_frames)}." ) + _validate_type(ch_names, (None, list, tuple), "ch_names") + if ch_names is not None: + for ci, ch_name in enumerate(ch_names): + _validate_type(ch_name, str, f"ch_names[{ci}]") with start_and_end_file(fname) as fid: - write_dig_points(fid, pts, block=True, coord_frame=coord_frame) + write_dig_points( + fid, pts, block=True, coord_frame=coord_frame, ch_names=ch_names + ) _cardinal_ident_mapping = { diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 427d4f12e54..1fc32f0163e 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -389,7 +389,7 @@ def write_ch_info(fid, ch): fid.write(b"\0" * (16 - len(ch_name))) -def write_dig_points(fid, dig, block=False, coord_frame=None): +def write_dig_points(fid, dig, block=False, coord_frame=None, *, ch_names=None): """Write a set of digitizer data points into a fif file.""" if dig is not None: data_size = 5 * 4 @@ -406,6 +406,10 @@ def write_dig_points(fid, dig, block=False, coord_frame=None): fid.write(np.array(d["kind"], ">i4").tobytes()) fid.write(np.array(d["ident"], ">i4").tobytes()) fid.write(np.array(d["r"][:3], ">f4").tobytes()) + if ch_names is not None: + write_name_list_sanitized( + fid, FIFF.FIFF_MNE_CH_NAME_LIST, ch_names, "ch_names" + ) if block: end_block(fid, FIFF.FIFFB_ISOTRAK) diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 5dad657a75c..a6ded682de9 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -46,6 +46,7 @@ _on_missing, _pl, _validate_type, + check_fname, copy_function_doc_to_method_doc, fill_doc, verbose, @@ -409,9 +410,22 @@ def save(self, fname, *, overwrite=False, verbose=None): The filename to use. Should end in .fif or .fif.gz. %(overwrite)s %(verbose)s + + See Also + -------- + mne.channels.read_dig_fif + + Notes + ----- + .. versionchanged:: 1.9 + Added support for saving the associated channel names. """ + fname = _check_fname(fname, overwrite=overwrite) + check_fname(fname, "montage", ("-dig.fif", "-dig.fif.gz")) coord_frame = _check_get_coord_frame(self.dig) - write_dig(fname, self.dig, coord_frame, overwrite=overwrite) + write_dig( + fname, self.dig, coord_frame, overwrite=overwrite, ch_names=self.ch_names + ) def __iadd__(self, other): """Add two DigMontages in place. @@ -808,17 +822,15 @@ def read_dig_dat(fname): return make_dig_montage(electrodes, nasion, lpa, rpa) -def read_dig_fif(fname): +@verbose +def read_dig_fif(fname, *, verbose=None): r"""Read digitized points from a .fif file. - Note that electrode names are not present in the .fif file so - they are here defined with the convention from VectorView - systems (EEG001, EEG002, etc.) - Parameters ---------- fname : path-like FIF file from which to read digitization locations. + %(verbose)s Returns ------- @@ -835,17 +847,28 @@ def read_dig_fif(fname): read_dig_hpts read_dig_localite make_dig_montage + + Notes + ----- + .. versionchanged:: 1.9 + Added support for reading the associated channel names, if present. + + In some files, electrode names are not present (e.g., in older files). + For those files, the channel names are defined with the convention from + VectorView systems (EEG001, EEG002, etc.). """ - fname = _check_fname(fname, overwrite="read", must_exist=True) + check_fname(fname, "montage", ("-dig.fif", "-dig.fif.gz")) + fname = _check_fname(fname=fname, must_exist=True, overwrite="read") # Load the dig data f, tree = fiff_open(fname)[:2] with f as fid: - dig = _read_dig_fif(fid, tree) + dig, ch_names = _read_dig_fif(fid, tree, return_ch_names=True) - ch_names = [] - for d in dig: - if d["kind"] == FIFF.FIFFV_POINT_EEG: - ch_names.append(f"EEG{d['ident']:03d}") + if ch_names is None: # backward compat from when we didn't save the names + ch_names = [] + for d in dig: + if d["kind"] == FIFF.FIFFV_POINT_EEG: + ch_names.append(f"EEG{d['ident']:03d}") montage = DigMontage(dig=dig, ch_names=ch_names) return montage @@ -1572,6 +1595,7 @@ def read_custom_montage( -------- make_dig_montage make_standard_montage + read_dig_fif Notes ----- diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 6ec54a271c2..8add1398409 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -19,6 +19,7 @@ assert_equal, ) +import mne.channels.montage from mne import ( __file__ as _mne_file, ) @@ -56,6 +57,7 @@ _BUILTIN_STANDARD_MONTAGES, _check_get_coord_frame, transform_to_head, + write_dig, ) from mne.coreg import get_mni_fiducials from mne.datasets import testing @@ -138,7 +140,8 @@ def test_dig_montage_trans(tmp_path): _ensure_trans(trans) # ensure that we can save and load it, too fname = tmp_path / "temp-mon.fif" - _check_roundtrip(montage, fname, "mri") + with pytest.warns(RuntimeWarning, match="MNE naming conventions"): + _check_roundtrip(montage, fname, "mri") # test applying a trans position1 = montage.get_positions() montage.apply_trans(trans) @@ -1074,12 +1077,12 @@ def _ensure_fid_not_nan(info, ch_pos): @testing.requires_testing_data -def test_fif_dig_montage(tmp_path): +def test_fif_dig_montage(tmp_path, monkeypatch): """Test FIF dig montage support.""" - dig_montage = read_dig_fif(fif_dig_montage_fname) + dig_montage = read_dig_fif(fif_dig_montage_fname, verbose="error") # test round-trip IO - fname_temp = tmp_path / "test.fif" + fname_temp = tmp_path / "test-dig.fif" _check_roundtrip(dig_montage, fname_temp) # Make a BrainVision file like the one the user would have had @@ -1119,16 +1122,32 @@ def test_fif_dig_montage(tmp_path): # Roundtrip of non-FIF start montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), hpi=read_mrk(hpi)) elp_points = read_polhemus_fastscan(elp) - ch_pos = {f"EEG{k:03d}": pos for k, pos in enumerate(elp_points[8:], 1)} - montage += make_dig_montage( + ch_pos = {f"ECoG{k:03d}": pos for k, pos in enumerate(elp_points[3:], 1)} + assert len(elp_points) == 8 # there are only 8 but pretend the last are ECoG + other = make_dig_montage( nasion=elp_points[0], lpa=elp_points[1], rpa=elp_points[2], ch_pos=ch_pos ) + assert other.ch_names[0].startswith("ECoG") + montage += other + assert montage.ch_names[0].startswith("ECoG") _check_roundtrip(montage, fname_temp, "unknown") montage = transform_to_head(montage) _check_roundtrip(montage, fname_temp) montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_UNKNOWN with pytest.raises(RuntimeError, match="Only a single coordinate"): - montage.save(fname_temp) + montage.save(fname_temp, overwrite=True) + montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_HEAD + + # Check that old-style files can be read, too, using EEG001 etc. + def write_dig_no_ch_names(*args, **kwargs): + kwargs["ch_names"] = None + return write_dig(*args, **kwargs) + + monkeypatch.setattr(mne.channels.montage, "write_dig", write_dig_no_ch_names) + montage.save(fname_temp, overwrite=True) + montage_read = read_dig_fif(fname_temp) + default_ch_names = [f"EEG{ii:03d}" for ii in range(1, 6)] + assert montage_read.ch_names == default_ch_names @testing.requires_testing_data @@ -1175,8 +1194,8 @@ def test_egi_dig_montage(tmp_path): atol=1e-4, ) - # test round-trip IO - fname_temp = tmp_path / "egi_test.fif" + # test round-trip IO (with GZ) + fname_temp = tmp_path / "egi_test-dig.fif.gz" _check_roundtrip(dig_montage, fname_temp, "unknown") _check_roundtrip(dig_montage_in_head, fname_temp) @@ -1330,7 +1349,7 @@ def test_read_dig_captrak(tmp_path): ) montage = transform_to_head(montage) # transform_to_head has to be tested - _check_roundtrip(montage=montage, fname=str(tmp_path / "bvct_test.fif")) + _check_roundtrip(montage=montage, fname=tmp_path / "bvct_test-dig.fif") fid, _ = _get_fid_coords(montage.dig) assert_allclose( @@ -1495,15 +1514,15 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): assert_array_less(0, ang) # but not equal -# XXX: this does not check ch_names + it cannot work because of write_dig def _check_roundtrip(montage, fname, coord_frame="head"): """Check roundtrip writing.""" montage.save(fname, overwrite=True) montage_read = read_dig_fif(fname=fname) - assert_equal(repr(montage), repr(montage_read)) - assert_equal(_check_get_coord_frame(montage_read.dig), coord_frame) + assert repr(montage) == repr(montage_read) + assert _check_get_coord_frame(montage_read.dig) == coord_frame assert_dig_allclose(montage, montage_read) + assert montage.ch_names == montage_read.ch_names def test_digmontage_constructor_errors(): @@ -1910,7 +1929,7 @@ def test_get_montage(): # 4. read in BV test dataset and make sure montage # fulfills roundtrip on non-standard montage - dig_montage = read_dig_fif(fif_dig_montage_fname) + dig_montage = read_dig_fif(fif_dig_montage_fname, verbose="error") # Make a BrainVision file like the one the user would have had # with testing dataset 'test.vhdr' diff --git a/mne/viz/tests/test_montage.py b/mne/viz/tests/test_montage.py index e9954d2c115..c496dd50d32 100644 --- a/mne/viz/tests/test_montage.py +++ b/mne/viz/tests/test_montage.py @@ -48,7 +48,7 @@ def test_plot_montage(): assert "0 channels" in repr(montage) with pytest.raises(RuntimeError, match="No valid channel positions"): montage.plot() - d = read_dig_fif(fname=fif_fname) + d = read_dig_fif(fname=fif_fname, verbose="error") assert "61 channels" in repr(d) # XXX this is broken; dm.point_names is used. Sometimes we say this should # Just contain the HPI coils, other times that it's all channels (e.g., diff --git a/tutorials/forward/35_eeg_no_mri.py b/tutorials/forward/35_eeg_no_mri.py index 422fe8ea580..f1b4e6de6fd 100644 --- a/tutorials/forward/35_eeg_no_mri.py +++ b/tutorials/forward/35_eeg_no_mri.py @@ -105,7 +105,7 @@ # in the MRI coordinate frame, which can be used to compute the # MRI<->head transform ``trans``: fname_1020 = subjects_dir / subject / "montages" / "10-20-montage.fif" -mon = mne.channels.read_dig_fif(fname_1020) +mon = mne.channels.read_dig_fif(fname_1020, verbose="error") # should be named -dig.fif mon.rename_channels({f"EEG{ii:03d}": ch_name for ii, ch_name in enumerate(ch_names, 1)}) trans = mne.channels.compute_native_head_t(mon) raw.set_montage(mon)