From 042d6384519f3f6ec46a1ce9ce960d19096c52ae Mon Sep 17 00:00:00 2001 From: timonmerk Date: Sat, 23 Nov 2024 11:49:38 +0100 Subject: [PATCH] add no psd feature normalization option to settings --- py_neuromodulation/default_settings.yaml | 1 + .../processing/data_preprocessor.py | 7 ++++ .../processing/normalization.py | 33 ++++++++++++++----- py_neuromodulation/stream/data_processor.py | 27 +++++++++++++-- py_neuromodulation/stream/settings.py | 4 +-- 5 files changed, 58 insertions(+), 14 deletions(-) diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml index 49988840..c50b5f8e 100644 --- a/py_neuromodulation/default_settings.yaml +++ b/py_neuromodulation/default_settings.yaml @@ -71,6 +71,7 @@ postprocessing: feature_normalization_settings: normalization_time_s: 30 normalization_method: zscore # supported methods: mean, median, zscore, zscore-median, quantile, power, robust, minmax + normalize_psd: false clip: 3 project_cortex_settings: diff --git a/py_neuromodulation/processing/data_preprocessor.py b/py_neuromodulation/processing/data_preprocessor.py index 2946ba6c..319926d9 100644 --- a/py_neuromodulation/processing/data_preprocessor.py +++ b/py_neuromodulation/processing/data_preprocessor.py @@ -72,6 +72,13 @@ def instantiate_preprocessor( ] def process_data(self, data: "np.ndarray") -> "np.ndarray": + """ + Args: + data (np.ndarray): shape: (n_channels, n_samples) + + Returns: + np.ndarray: shape: (n_channels, n_samples) + """ for preprocessor in self.preprocessors: data = preprocessor.process(data) return data diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py index 3345b122..e3f25d55 100644 --- a/py_neuromodulation/processing/normalization.py +++ b/py_neuromodulation/processing/normalization.py @@ -25,6 +25,7 @@ class NormalizationSettings(NMBaseModel): def list_normalization_methods() -> list[NormMethod]: return list(get_args(NormMethod)) +class FeatureNormalizationSettings(NormalizationSettings): normalize_psd: bool = False class Normalizer(NMPreprocessor): def __init__( @@ -32,9 +33,13 @@ def __init__( sfreq: float, settings: "NMSettings", type: NormalizerType, + **kwargs, ) -> None: self.type = type - self.settings: NormalizationSettings + if self.type == "raw": + self.settings: NormalizationSettings + else: + self.settings: FeatureNormalizationSettings match self.type: case "raw": @@ -74,14 +79,24 @@ def __init__( self.normalizer = NORM_FUNCTIONS[self.method] def process(self, data: np.ndarray) -> np.ndarray: - # TODO: does feature normalization need to be transposed too? - if self.type == "raw": - data = data.T + """Process normalization. + Note: raw data has to be internally transposed, s.t. raw and features + are normalized in the same way. + Args: + data (np.ndarray): shape (channels, n_samples) + + Returns: + np.ndarray: (channels, n_samples) + """ + if self.previous.size == 0: # Check if empty self.previous = data - return data if self.type == "raw" else data.T - + if self.type == "raw": + self.previous = self.previous.T + return data + if self.type == "raw": + data = data.T self.previous = np.vstack((self.previous, data[-self.add_samples :])) data = self.normalizer(data, self.previous) @@ -93,12 +108,12 @@ def process(self, data: np.ndarray) -> np.ndarray: data = np.nan_to_num(data) - return data if self.type == "raw" else data.T + return data if self.type != "raw" else data.T class RawNormalizer(Normalizer): - def __init__(self, sfreq: float, settings: "NMSettings") -> None: - super().__init__(sfreq, settings, "raw") + def __init__(self, sfreq: float, settings: "NMSettings", **kwargs,) -> None: + super().__init__(sfreq, settings, "raw", **kwargs) class FeatureNormalizer(Normalizer): diff --git a/py_neuromodulation/stream/data_processor.py b/py_neuromodulation/stream/data_processor.py index 5fea5ce9..966b5ff7 100644 --- a/py_neuromodulation/stream/data_processor.py +++ b/py_neuromodulation/stream/data_processor.py @@ -55,6 +55,7 @@ def __init__( self.sfreq_raw: float = sfreq // 1 self.line_noise: float | None = line_noise self.path_grids: _PathLike | None = path_grids + self.non_psd_indices: np.ndarray | None = None self.verbose: bool = verbose self.features_previous = None @@ -255,9 +256,29 @@ def process(self, data: np.ndarray) -> dict[str, float]: # normalize features if self.settings.postprocessing.feature_normalization: - normed_features = self.feature_normalizer.process( - np.fromiter(features_dict.values(), dtype=np.float64) - ) + if not self.settings.feature_normalization_settings.normalize_psd: + if self.non_psd_indices is None: + self.non_psd_indices = [ + idx + for idx, key in enumerate(features_dict.keys()) + if "psd" not in key + ] + self.psd_indices = list(set(range(len(features_dict))) - set( + self.non_psd_indices + )) + feature_values = np.fromiter(features_dict.values(), dtype=np.float64) + normed_features_non_psd = self.feature_normalizer.process( + feature_values[self.non_psd_indices] + ) + + # combine values in new array + normed_features = np.empty((feature_values.shape[0])) + normed_features[self.non_psd_indices] = normed_features_non_psd + normed_features[self.psd_indices] = feature_values[self.psd_indices] + else: + normed_features = self.feature_normalizer.process( + np.fromiter(features_dict.values(), dtype=np.float64) + ) features_dict = { key: normed_features[idx] for idx, key in enumerate(features_dict.keys()) diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py index 7e25ec24..669c97db 100644 --- a/py_neuromodulation/stream/settings.py +++ b/py_neuromodulation/stream/settings.py @@ -16,7 +16,7 @@ ) from py_neuromodulation.processing.filter_preprocessing import FilterSettings -from py_neuromodulation.processing.normalization import NormalizationSettings +from py_neuromodulation.processing.normalization import FeatureNormalizationSettings, NormalizationSettings from py_neuromodulation.processing.resample import ResamplerSettings from py_neuromodulation.processing.projection import ProjectionSettings @@ -83,7 +83,7 @@ class NMSettings(NMBaseModel): # Postprocessing settings postprocessing: PostprocessingSettings = PostprocessingSettings() - feature_normalization_settings: NormalizationSettings = NormalizationSettings() + feature_normalization_settings: FeatureNormalizationSettings = FeatureNormalizationSettings() project_cortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=20) project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5)