Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 30, 2024
1 parent 824457d commit 5e5f16b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
7 changes: 4 additions & 3 deletions examples/preprocessing/interpolate_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
# Authors: Antoine Collas <[email protected]>
# License: BSD-3-Clause

import matplotlib.pyplot as plt

import mne
from mne.datasets import sample
from mne.channels import make_standard_montage
import matplotlib.pyplot as plt
from mne.datasets import sample

print(__doc__)

Expand All @@ -39,7 +40,7 @@

# %%
# Define the target montage
standard_montage = make_standard_montage('standard_1020')
standard_montage = make_standard_montage("standard_1020")

# %%
# Use interpolate_to to project EEG data to the standard montage
Expand Down
41 changes: 25 additions & 16 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def interpolate_bads(

return self

def interpolate_to(self, montage, method='MNE', reg=0.0):
def interpolate_to(self, montage, method="MNE", reg=0.0):
"""Interpolate data onto a new montage.
Parameters
Expand All @@ -968,6 +968,7 @@ def interpolate_to(self, montage, method='MNE', reg=0.0):
The instance with updated channel locations and data.
"""
import numpy as np

import mne
from mne import pick_types
from mne.forward._field_interpolation import _map_meg_or_eeg_channels
Expand All @@ -980,58 +981,66 @@ def interpolate_to(self, montage, method='MNE', reg=0.0):
if len(picks_from) == 0:
raise ValueError("No EEG channels available for interpolation.")

if hasattr(self, '_data'):
if hasattr(self, "_data"):
data_orig = self._data[picks_from]
else:
# If epochs-like data, for simplicity take the mean across epochs
data_orig = self.get_data()[:, picks_from, :].mean(axis=0)

# Get target positions from the montage
ch_pos = montage.get_positions()['ch_pos']
ch_pos = montage.get_positions()["ch_pos"]
target_ch_names = list(ch_pos.keys())
if len(target_ch_names) == 0:
raise ValueError("The provided montage does not contain any channel positions.")
raise ValueError(
"The provided montage does not contain any channel positions."
)

# Create a new info structure using MNE public API
sfreq = self.info['sfreq']
ch_types = ['eeg'] * len(target_ch_names)
new_info = mne.create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types)
sfreq = self.info["sfreq"]
ch_types = ["eeg"] * len(target_ch_names)
new_info = mne.create_info(
ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types
)
new_info.set_montage(montage)

# Create a simple old_info
sfreq = self.info['sfreq']
ch_names = self.info['ch_names']
ch_types = ['eeg'] * len(ch_names)
sfreq = self.info["sfreq"]
ch_names = self.info["ch_names"]
ch_types = ["eeg"] * len(ch_names)
old_info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
old_info.set_montage(self.info.get_montage())

# Compute mapping from current montage to target montage
mapping = _map_meg_or_eeg_channels(old_info, new_info, mode='accurate', origin='auto')
mapping = _map_meg_or_eeg_channels(
old_info, new_info, mode="accurate", origin="auto"
)

# Apply the interpolation mapping
D_new = mapping.dot(data_orig)

# Update bad channels
new_bads = [ch for ch in self.info['bads'] if ch in target_ch_names]
new_info['bads'] = new_bads
new_bads = [ch for ch in self.info["bads"] if ch in target_ch_names]
new_info["bads"] = new_bads

# Update the instance's info and data
self.info = new_info
if hasattr(self, '_data'):
if hasattr(self, "_data"):
if self._data.ndim == 2:
# Raw-like: directly assign the new data
self._data = D_new
else:
# Epochs-like
n_epochs, _, n_times = self._data.shape
new_data = np.zeros((n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype)
new_data = np.zeros(
(n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype
)
for e in range(n_epochs):
epoch_data_orig = self._data[e, picks_from, :]
new_data[e, :, :] = mapping.dot(epoch_data_orig)
self._data = new_data
else:
# Evoked-like data
if hasattr(self, 'data'):
if hasattr(self, "data"):
self.data = D_new
else:
raise NotImplementedError("This method requires preloaded data.")
Expand Down

0 comments on commit 5e5f16b

Please sign in to comment.