Skip to content

Commit

Permalink
ENH: Add round-trip channel name saving (mne-tools#13003)
Browse files Browse the repository at this point in the history
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Stefan Appelhoff <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent a1a05ae commit 7071a0e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 34 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/13003.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for saving and loading channel names from FIF in :meth:`mne.channels.DigMontage.save` and :meth:`mne.channels.read_dig_fif` and added the convention that they should be saved as ``-dig.fif``, by `Eric Larson`_.
29 changes: 24 additions & 5 deletions mne/_fiff/_digitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .constants import FIFF, _coord_frame_named
from .tag import read_tag
from .tree import dir_tree_find
from .write import start_and_end_file, write_dig_points
from .write import _safe_name_list, start_and_end_file, write_dig_points

_dig_kind_dict = {
"cardinal": FIFF.FIFFV_POINT_CARDINAL,
Expand Down Expand Up @@ -162,10 +162,11 @@ def __eq__(self, other): # noqa: D105
return np.allclose(self["r"], other["r"])


def _read_dig_fif(fid, meas_info):
def _read_dig_fif(fid, meas_info, *, return_ch_names=False):
"""Read digitizer data from a FIFF file."""
isotrak = dir_tree_find(meas_info, FIFF.FIFFB_ISOTRAK)
dig = None
ch_names = None
if len(isotrak) == 0:
logger.info("Isotrak not found")
elif len(isotrak) > 1:
Expand All @@ -183,13 +184,21 @@ def _read_dig_fif(fid, meas_info):
elif kind == FIFF.FIFF_MNE_COORD_FRAME:
tag = read_tag(fid, pos)
coord_frame = _coord_frame_named.get(int(tag.data.item()))
elif kind == FIFF.FIFF_MNE_CH_NAME_LIST:
tag = read_tag(fid, pos)
ch_names = _safe_name_list(tag.data, "read", "ch_names")
for d in dig:
d["coord_frame"] = coord_frame
return _format_dig_points(dig)
out = _format_dig_points(dig)
if return_ch_names:
out = (out, ch_names)
return out


@verbose
def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None):
def write_dig(
fname, pts, coord_frame=None, *, ch_names=None, overwrite=False, verbose=None
):
"""Write digitization data to a FIF file.
Parameters
Expand All @@ -203,6 +212,10 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None):
If all the points have the same coordinate frame, specify the type
here. Can be None (default) if the points could have varying
coordinate frames.
ch_names : list of str | None
Channel names associated with the digitization points, if available.
.. versionadded:: 1.9
%(overwrite)s
.. versionadded:: 1.0
Expand All @@ -222,9 +235,15 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None):
"Points have coord_frame entries that are incompatible with "
f"coord_frame={coord_frame}: {tuple(bad_frames)}."
)
_validate_type(ch_names, (None, list, tuple), "ch_names")
if ch_names is not None:
for ci, ch_name in enumerate(ch_names):
_validate_type(ch_name, str, f"ch_names[{ci}]")

with start_and_end_file(fname) as fid:
write_dig_points(fid, pts, block=True, coord_frame=coord_frame)
write_dig_points(
fid, pts, block=True, coord_frame=coord_frame, ch_names=ch_names
)


_cardinal_ident_mapping = {
Expand Down
6 changes: 5 additions & 1 deletion mne/_fiff/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def write_ch_info(fid, ch):
fid.write(b"\0" * (16 - len(ch_name)))


def write_dig_points(fid, dig, block=False, coord_frame=None):
def write_dig_points(fid, dig, block=False, coord_frame=None, *, ch_names=None):
"""Write a set of digitizer data points into a fif file."""
if dig is not None:
data_size = 5 * 4
Expand All @@ -406,6 +406,10 @@ def write_dig_points(fid, dig, block=False, coord_frame=None):
fid.write(np.array(d["kind"], ">i4").tobytes())
fid.write(np.array(d["ident"], ">i4").tobytes())
fid.write(np.array(d["r"][:3], ">f4").tobytes())
if ch_names is not None:
write_name_list_sanitized(
fid, FIFF.FIFF_MNE_CH_NAME_LIST, ch_names, "ch_names"
)
if block:
end_block(fid, FIFF.FIFFB_ISOTRAK)

Expand Down
48 changes: 36 additions & 12 deletions mne/channels/montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
_on_missing,
_pl,
_validate_type,
check_fname,
copy_function_doc_to_method_doc,
fill_doc,
verbose,
Expand Down Expand Up @@ -409,9 +410,22 @@ def save(self, fname, *, overwrite=False, verbose=None):
The filename to use. Should end in .fif or .fif.gz.
%(overwrite)s
%(verbose)s
See Also
--------
mne.channels.read_dig_fif
Notes
-----
.. versionchanged:: 1.9
Added support for saving the associated channel names.
"""
fname = _check_fname(fname, overwrite=overwrite)
check_fname(fname, "montage", ("-dig.fif", "-dig.fif.gz"))
coord_frame = _check_get_coord_frame(self.dig)
write_dig(fname, self.dig, coord_frame, overwrite=overwrite)
write_dig(
fname, self.dig, coord_frame, overwrite=overwrite, ch_names=self.ch_names
)

def __iadd__(self, other):
"""Add two DigMontages in place.
Expand Down Expand Up @@ -808,17 +822,15 @@ def read_dig_dat(fname):
return make_dig_montage(electrodes, nasion, lpa, rpa)


def read_dig_fif(fname):
@verbose
def read_dig_fif(fname, *, verbose=None):
r"""Read digitized points from a .fif file.
Note that electrode names are not present in the .fif file so
they are here defined with the convention from VectorView
systems (EEG001, EEG002, etc.)
Parameters
----------
fname : path-like
FIF file from which to read digitization locations.
%(verbose)s
Returns
-------
Expand All @@ -835,17 +847,28 @@ def read_dig_fif(fname):
read_dig_hpts
read_dig_localite
make_dig_montage
Notes
-----
.. versionchanged:: 1.9
Added support for reading the associated channel names, if present.
In some files, electrode names are not present (e.g., in older files).
For those files, the channel names are defined with the convention from
VectorView systems (EEG001, EEG002, etc.).
"""
fname = _check_fname(fname, overwrite="read", must_exist=True)
check_fname(fname, "montage", ("-dig.fif", "-dig.fif.gz"))
fname = _check_fname(fname=fname, must_exist=True, overwrite="read")
# Load the dig data
f, tree = fiff_open(fname)[:2]
with f as fid:
dig = _read_dig_fif(fid, tree)
dig, ch_names = _read_dig_fif(fid, tree, return_ch_names=True)

ch_names = []
for d in dig:
if d["kind"] == FIFF.FIFFV_POINT_EEG:
ch_names.append(f"EEG{d['ident']:03d}")
if ch_names is None: # backward compat from when we didn't save the names
ch_names = []
for d in dig:
if d["kind"] == FIFF.FIFFV_POINT_EEG:
ch_names.append(f"EEG{d['ident']:03d}")

montage = DigMontage(dig=dig, ch_names=ch_names)
return montage
Expand Down Expand Up @@ -1572,6 +1595,7 @@ def read_custom_montage(
--------
make_dig_montage
make_standard_montage
read_dig_fif
Notes
-----
Expand Down
47 changes: 33 additions & 14 deletions mne/channels/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
assert_equal,
)

import mne.channels.montage
from mne import (
__file__ as _mne_file,
)
Expand Down Expand Up @@ -56,6 +57,7 @@
_BUILTIN_STANDARD_MONTAGES,
_check_get_coord_frame,
transform_to_head,
write_dig,
)
from mne.coreg import get_mni_fiducials
from mne.datasets import testing
Expand Down Expand Up @@ -138,7 +140,8 @@ def test_dig_montage_trans(tmp_path):
_ensure_trans(trans)
# ensure that we can save and load it, too
fname = tmp_path / "temp-mon.fif"
_check_roundtrip(montage, fname, "mri")
with pytest.warns(RuntimeWarning, match="MNE naming conventions"):
_check_roundtrip(montage, fname, "mri")
# test applying a trans
position1 = montage.get_positions()
montage.apply_trans(trans)
Expand Down Expand Up @@ -1074,12 +1077,12 @@ def _ensure_fid_not_nan(info, ch_pos):


@testing.requires_testing_data
def test_fif_dig_montage(tmp_path):
def test_fif_dig_montage(tmp_path, monkeypatch):
"""Test FIF dig montage support."""
dig_montage = read_dig_fif(fif_dig_montage_fname)
dig_montage = read_dig_fif(fif_dig_montage_fname, verbose="error")

# test round-trip IO
fname_temp = tmp_path / "test.fif"
fname_temp = tmp_path / "test-dig.fif"
_check_roundtrip(dig_montage, fname_temp)

# Make a BrainVision file like the one the user would have had
Expand Down Expand Up @@ -1119,16 +1122,32 @@ def test_fif_dig_montage(tmp_path):
# Roundtrip of non-FIF start
montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), hpi=read_mrk(hpi))
elp_points = read_polhemus_fastscan(elp)
ch_pos = {f"EEG{k:03d}": pos for k, pos in enumerate(elp_points[8:], 1)}
montage += make_dig_montage(
ch_pos = {f"ECoG{k:03d}": pos for k, pos in enumerate(elp_points[3:], 1)}
assert len(elp_points) == 8 # there are only 8 but pretend the last are ECoG
other = make_dig_montage(
nasion=elp_points[0], lpa=elp_points[1], rpa=elp_points[2], ch_pos=ch_pos
)
assert other.ch_names[0].startswith("ECoG")
montage += other
assert montage.ch_names[0].startswith("ECoG")
_check_roundtrip(montage, fname_temp, "unknown")
montage = transform_to_head(montage)
_check_roundtrip(montage, fname_temp)
montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_UNKNOWN
with pytest.raises(RuntimeError, match="Only a single coordinate"):
montage.save(fname_temp)
montage.save(fname_temp, overwrite=True)
montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_HEAD

# Check that old-style files can be read, too, using EEG001 etc.
def write_dig_no_ch_names(*args, **kwargs):
kwargs["ch_names"] = None
return write_dig(*args, **kwargs)

monkeypatch.setattr(mne.channels.montage, "write_dig", write_dig_no_ch_names)
montage.save(fname_temp, overwrite=True)
montage_read = read_dig_fif(fname_temp)
default_ch_names = [f"EEG{ii:03d}" for ii in range(1, 6)]
assert montage_read.ch_names == default_ch_names


@testing.requires_testing_data
Expand Down Expand Up @@ -1175,8 +1194,8 @@ def test_egi_dig_montage(tmp_path):
atol=1e-4,
)

# test round-trip IO
fname_temp = tmp_path / "egi_test.fif"
# test round-trip IO (with GZ)
fname_temp = tmp_path / "egi_test-dig.fif.gz"
_check_roundtrip(dig_montage, fname_temp, "unknown")
_check_roundtrip(dig_montage_in_head, fname_temp)

Expand Down Expand Up @@ -1330,7 +1349,7 @@ def test_read_dig_captrak(tmp_path):
)

montage = transform_to_head(montage) # transform_to_head has to be tested
_check_roundtrip(montage=montage, fname=str(tmp_path / "bvct_test.fif"))
_check_roundtrip(montage=montage, fname=tmp_path / "bvct_test-dig.fif")

fid, _ = _get_fid_coords(montage.dig)
assert_allclose(
Expand Down Expand Up @@ -1495,15 +1514,15 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads):
assert_array_less(0, ang) # but not equal


# XXX: this does not check ch_names + it cannot work because of write_dig
def _check_roundtrip(montage, fname, coord_frame="head"):
"""Check roundtrip writing."""
montage.save(fname, overwrite=True)
montage_read = read_dig_fif(fname=fname)

assert_equal(repr(montage), repr(montage_read))
assert_equal(_check_get_coord_frame(montage_read.dig), coord_frame)
assert repr(montage) == repr(montage_read)
assert _check_get_coord_frame(montage_read.dig) == coord_frame
assert_dig_allclose(montage, montage_read)
assert montage.ch_names == montage_read.ch_names


def test_digmontage_constructor_errors():
Expand Down Expand Up @@ -1910,7 +1929,7 @@ def test_get_montage():

# 4. read in BV test dataset and make sure montage
# fulfills roundtrip on non-standard montage
dig_montage = read_dig_fif(fif_dig_montage_fname)
dig_montage = read_dig_fif(fif_dig_montage_fname, verbose="error")

# Make a BrainVision file like the one the user would have had
# with testing dataset 'test.vhdr'
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_plot_montage():
assert "0 channels" in repr(montage)
with pytest.raises(RuntimeError, match="No valid channel positions"):
montage.plot()
d = read_dig_fif(fname=fif_fname)
d = read_dig_fif(fname=fif_fname, verbose="error")
assert "61 channels" in repr(d)
# XXX this is broken; dm.point_names is used. Sometimes we say this should
# Just contain the HPI coils, other times that it's all channels (e.g.,
Expand Down
2 changes: 1 addition & 1 deletion tutorials/forward/35_eeg_no_mri.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
# in the MRI coordinate frame, which can be used to compute the
# MRI<->head transform ``trans``:
fname_1020 = subjects_dir / subject / "montages" / "10-20-montage.fif"
mon = mne.channels.read_dig_fif(fname_1020)
mon = mne.channels.read_dig_fif(fname_1020, verbose="error") # should be named -dig.fif
mon.rename_channels({f"EEG{ii:03d}": ch_name for ii, ch_name in enumerate(ch_names, 1)})
trans = mne.channels.compute_native_head_t(mon)
raw.set_montage(mon)
Expand Down

0 comments on commit 7071a0e

Please sign in to comment.