Skip to content

Commit

Permalink
[ENH] Make SpatioSpectralFilter class more similar to `scikit-learn` fit-transform classes. (braindatalab#22)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tsbinns authored Sep 16, 2024
1 parent ea80d34 commit fa10009
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 37 deletions.
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

##### Bug Fixes
- Fixed error where the number of subplots exceeding the number of nodes would cause plotting to fail.
- Fixed error where bandpass filter settings for the SSD method in `SpatioSpectralFilter` were not being applied correctly.

##### API
- Changed the default value of `min_ratio` in `SpatioSpectralFilter.get_transformed_data()` from `1.0` to `-inf`.
- Added the option to control whether a copy is returned from the `get_results()` method of all `Results...` classes and from `SpatioSpectralFilter.get_transformed_data()` (default behaviour returns a copy, like in previous versions).
- Added new `fit_ssd()`, `fit_hpmax()`, and `transform()` methods to the `SpatioSpectralFilter` class to bring it more in line with `scikit-learn` fit-transform classes.

##### Documentation
- Added a new example for computing the bispectrum and threenorm using the general classes.
Expand Down
8 changes: 5 additions & 3 deletions examples/plot_compute_waveshape_noisy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@

# perform spatio-spectral filtering
ssf = SpatioSpectralFilter(data=data, sampling_freq=sampling_freq, verbose=False)
ssf.fit_transform_hpmax(signal_bounds=(18, 22), noise_bounds=(15, 25), n_harmonics=2)
transformed_data = ssf.fit_transform_hpmax(
signal_bounds=(18, 22), noise_bounds=(15, 25), n_harmonics=2
)

# return the first component of the filtered data
transformed_data = (ssf.get_transformed_data()[:, 0])[:, np.newaxis, :]
# select the first component of the filtered data
transformed_data = transformed_data[:, [0], :]

print(
f"Original timeseries data: [{data.shape[0]} epochs x {data.shape[1]} channel(s) x "
Expand Down
4 changes: 2 additions & 2 deletions src/pybispectra/cfc/aac.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def _compute_aac(self) -> None:
)
except MemoryError as error: # pragma: no cover
raise MemoryError(
"Memory allocation for the bispectrum computation failed. Try reducing "
"the sampling frequency of the data, or reduce the precision of the "
"Memory allocation for the AAC computation failed. Try reducing the "
"sampling frequency of the data, or reduce the precision of the "
"computation with `pybispectra.set_precision('single')`."
) from error

Expand Down
2 changes: 1 addition & 1 deletion src/pybispectra/cfc/pac.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _compute_threenorm(self) -> None:
).transpose(1, 0, 2, 3)
except MemoryError as error: # pragma: no cover
raise MemoryError(
"Memory allocation for the bispectrum computation failed. Try reducing "
"Memory allocation for the threenorm computation failed. Try reducing "
"the sampling frequency of the data, or reduce the precision of the "
"computation with `pybispectra.set_precision('single')`."
) from error
Expand Down
4 changes: 2 additions & 2 deletions src/pybispectra/cfc/ppc.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def _compute_ppc(self) -> None:
)
except MemoryError as error: # pragma: no cover
raise MemoryError(
"Memory allocation for the bispectrum computation failed. Try reducing "
"the sampling frequency of the data, or reduce the precision of the "
"Memory allocation for the PPC computation failed. Try reducing the "
"sampling frequency of the data, or reduce the precision of the "
"computation with `pybispectra.set_precision('single')`."
) from error

Expand Down
2 changes: 1 addition & 1 deletion src/pybispectra/general/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _compute_threenorm(self) -> None:
).transpose(1, 0, 2, 3)
except MemoryError as error: # pragma: no cover
raise MemoryError(
"Memory allocation for the bispectrum computation failed. Try reducing "
"Memory allocation for the threenorm computation failed. Try reducing "
"the sampling frequency of the data, or reduce the precision of the "
"computation with `pybispectra.set_precision('single')`."
) from error
Expand Down
6 changes: 5 additions & 1 deletion src/pybispectra/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from mne import Info, create_info
from mne.parallel import parallel_func
from mne.utils import ProgressBar
from mne.utils import ProgressBar, set_log_level
from numba import njit

from pybispectra.utils._defaults import _precision
Expand Down Expand Up @@ -66,11 +66,15 @@ def _compute_in_parallel(
parallel, my_parallel_func, _ = parallel_func(
func, n_jobs, prefer=prefer, verbose=verbose
)
old_log_level = set_log_level(
verbose="INFO" if verbose else "WARNING", return_old_level=True
) # need to set log level that is passed to tqdm
for block_i in ProgressBar(range(n_blocks), mesg=message):
idcs = _get_block_indices(block_i, n_steps, n_jobs)
output[idcs] = parallel(
my_parallel_func(**loop_kwargs[idx], **static_kwargs) for idx in idcs
)
set_log_level(verbose=old_log_level) # reset log level

return output

Expand Down
155 changes: 138 additions & 17 deletions src/pybispectra/utils/ged.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class SpatioSpectralFilter:
Parameters
----------
data : ~numpy.ndarray, shape of [epochs, channels, times]
Data to perform spatiospectral filtering on.
sampling_freq : int | float
Sampling frequency (in Hz) of :attr:`data`.
Expand All @@ -29,6 +30,15 @@ class SpatioSpectralFilter:
Methods
-------
fit_hpmax :
Fit HPMax filters to the data.
fit_ssd :
Fit SSD filters to the data.
transform :
Transform the data with the fitted filters.
fit_transform_hpmax :
Fit HPMax filters and transform the data.
Expand Down Expand Up @@ -141,9 +151,12 @@ class SpatioSpectralFilter:
filters = None
patterns = None
ratios = None
_ssd = None
_transformed_data = None

_fitted = False
_fitted_method = None
_transformed = False

def __init__(
self,
Expand Down Expand Up @@ -214,7 +227,7 @@ def _sort_bandpass_filter(self, bandpass_filter: bool) -> None:
if not isinstance(bandpass_filter, bool):
raise TypeError("`bandpass_filter` must be a bool.")

self.bandpass_filter = True
self.bandpass_filter = bandpass_filter

def _sort_n_harmonics(self, n_harmonics: int) -> None:
"""Sort harmonic use input."""
Expand Down Expand Up @@ -283,7 +296,7 @@ def _sort_csd_method(self, csd_method: str) -> None:
if csd_method not in accepted_methods:
raise ValueError("`csd_method` is not recognised.")

def fit_transform_ssd(
def fit_ssd(
self,
signal_bounds: tuple[int | float],
noise_bounds: tuple[int | float],
Expand All @@ -292,7 +305,7 @@ def fit_transform_ssd(
indices: tuple[int] | None = None,
rank: int | None = None,
) -> None:
"""Fit SSD filters and transform the data.
"""Fit SSD filters to the data.
Parameters
----------
Expand Down Expand Up @@ -325,6 +338,8 @@ def fit_transform_ssd(
-----
The SSD implementation in MNE is used to compute the filters
(:class:`mne.decoding.SSD`).
.. versionadded:: 1.2
"""
self._sort_freq_bounds(signal_bounds, noise_bounds, signal_noise_gap)
self._sort_bandpass_filter(bandpass_filter)
Expand All @@ -342,6 +357,7 @@ def fit_transform_ssd(
self._compute_ssd(info, filt_params_signal, filt_params_noise)

self._fitted = True
self._fitted_method = "SSD"

if self.verbose:
print(" ... SSD filter fitting finished\n")
Expand Down Expand Up @@ -397,7 +413,7 @@ def _compute_ssd(
"PyBispectra Internal Error: channel types in `info` should all be 'eeg'. "
"Please contact the PyBispectra developers."
)
ssd = SSD(
self._ssd = SSD(
info,
filt_params_signal,
filt_params_noise,
Expand All @@ -408,13 +424,13 @@ def _compute_ssd(
return_filtered=self.bandpass_filter,
rank={"eeg": self.rank},
)
self._transformed_data = ssd.fit_transform(self.data[:, self.indices])
self._ssd.fit(self.data[:, self.indices])

self.filters = ssd.filters_
self.patterns = ssd.patterns_
self.ratios = ssd.eigvals_
self.filters = self._ssd.filters_
self.patterns = self._ssd.patterns_
self.ratios = self._ssd.eigvals_

def fit_transform_hpmax(
def fit_hpmax(
self,
signal_bounds: tuple[int | float],
noise_bounds: tuple[int | float],
Expand All @@ -428,7 +444,7 @@ def fit_transform_hpmax(
mt_low_bias: bool = True,
n_jobs: int = 1,
) -> None:
"""Fit HPMax filters and transform the data.
"""Fit HPMax filters to the data.
Parameters
----------
Expand Down Expand Up @@ -484,6 +500,8 @@ def fit_transform_hpmax(
MNE is used to compute the CSD, from which the covariance matrices are obtained
:footcite:`Bartz2019` (:func:`mne.time_frequency.csd_array_multitaper` and
:func:`mne.time_frequency.csd_array_fourier`).
.. versionadded:: 1.2
"""
self._sort_freq_bounds(signal_bounds, noise_bounds, 0.0)
self._sort_n_harmonics(n_harmonics)
Expand All @@ -506,6 +524,7 @@ def fit_transform_hpmax(
self._compute_hpmax(csd, freqs)

self._fitted = True
self._fitted_method = "HPMax"

if self.verbose:
print(" ... HPMax filter fitting finished\n")
Expand Down Expand Up @@ -613,13 +632,6 @@ def _compute_hpmax(self, csd: np.ndarray, freqs: np.ndarray) -> None:
self.patterns = np.linalg.pinv(self.filters).astype(_precision.real)
self.ratios = eigvals[eig_idx].astype(_precision.real)

self._transformed_data = np.einsum(
"ijk,jl->ilk",
self.data[:, self.indices],
self.filters,
dtype=_precision.real,
)

if self.verbose:
print(" ... HPMax filter computation finished\n")

Expand Down Expand Up @@ -694,6 +706,108 @@ def _project_cov_rank_subspace(

return cov_signal, cov_noise, projection

def transform(self, data: np.ndarray | None = None) -> np.ndarray:
"""Transform the data with the fitted filters.
Parameters
----------
data : ~numpy.ndarray, shape of [epochs, channels, times] | None (default None)
Data to transform with the fitted filters. If :obj:`None`, the data used to
fit the filters is transformed.
Returns
-------
transformed_data : ~numpy.ndarray, shape of [epochs, components, times]
Transformed data.
Notes
-----
.. versionadded:: 1.2
"""
if not self._fitted:
raise ValueError(
"No filters have been fit. Please call `fit_ssd` or `fit_hpmax` before "
"transforming the data."
)

if data is None:
data = self.data
if not isinstance(data, np.ndarray):
raise TypeError("`data` must be a NumPy array.")
if data.ndim != 3:
raise ValueError("`data` must be a 3D array.")
if data.shape[1] != self.filters.shape[0]:
raise ValueError(
"`data` must have the same number of channels as the filters."
)

if self.verbose:
print("Transforming data with filters...\n")

if self.bandpass_filter and self._fitted_method == "SSD":
self._transformed_data = self._ssd.transform(data)
else:
self._transformed_data = np.einsum(
"ijk,jl->ilk",
data[:, self.indices],
self.filters,
dtype=_precision.real,
)

if self.verbose:
print(" ... Data transformation finished\n")

self._transformed = True

return self._transformed_data

def fit_transform_ssd(self, *args: tuple, **kwargs: dict) -> np.ndarray:
    """Fit SSD filters, then transform the data in a single call.

    Parameters
    ----------
    args : tuple
        Positional parameters forwarded unchanged to :meth:`fit_ssd`.

    kwargs : dict
        Keyword parameters forwarded unchanged to :meth:`fit_ssd`.

    Returns
    -------
    transformed_data : ~numpy.ndarray, shape of [epochs, components, times]
        Transformed data, as returned by :meth:`transform`.

    Notes
    -----
    Equivalent to calling :meth:`fit_ssd` followed by :meth:`transform` with no
    arguments.
    """
    # Fit first; `transform` is then called without data.
    self.fit_ssd(*args, **kwargs)
    transformed_data = self.transform()
    return transformed_data

def fit_transform_hpmax(self, *args: tuple, **kwargs: dict) -> np.ndarray:
    """Fit HPMax filters, then transform the data in a single call.

    Parameters
    ----------
    args : tuple
        Positional parameters forwarded unchanged to :meth:`fit_hpmax`.

    kwargs : dict
        Keyword parameters forwarded unchanged to :meth:`fit_hpmax`.

    Returns
    -------
    transformed_data : ~numpy.ndarray, shape of [epochs, components, times]
        Transformed data, as returned by :meth:`transform`.

    Notes
    -----
    Equivalent to calling :meth:`fit_hpmax` followed by :meth:`transform` with no
    arguments.
    """
    # Fit first; `transform` is then called without data.
    self.fit_hpmax(*args, **kwargs)
    transformed_data = self.transform()
    return transformed_data

def get_transformed_data(
self, min_ratio: int | float = -np.inf, copy: bool = True
) -> np.ndarray:
Expand Down Expand Up @@ -724,6 +838,13 @@ def get_transformed_data(
Raises a warning if no components have a signal-to-noise ratio > ``min_ratio``
and :attr:`verbose` is :obj:`True`.
"""
if not self._transformed:
raise ValueError(
"No data has been transformed. Please call `transform`, "
"`fit_transform_ssd`, or `fit_transform_hpmax` before getting the "
"transformed data."
)

if not isinstance(min_ratio, (int, float)):
raise TypeError("`min_ratio` must be an int or a float")
if not isinstance(copy, bool):
Expand Down
4 changes: 2 additions & 2 deletions src/pybispectra/waveshape/waveshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _compute_bispectrum(self) -> None:
(self._n_cons, 1, self._f1s.size, self._f2s.size),
dtype=_precision.complex,
),
message="Processing connections...",
message="Processing channels...",
n_jobs=self._n_jobs,
verbose=self.verbose,
prefer="processes",
Expand Down Expand Up @@ -234,7 +234,7 @@ def _compute_threenorm(self) -> None:
(self._n_cons, 1, self._f1s.size, self._f2s.size),
dtype=_precision.real,
),
message="Processing connections...",
message="Processing channels...",
n_jobs=self._n_jobs,
verbose=self.verbose,
prefer="processes",
Expand Down
Loading

0 comments on commit fa10009

Please sign in to comment.