diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 1b28be9752..8f3729b49b 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -24,6 +24,8 @@ CenterRecording, center, ) +from .scale import scale_to_uV + from .whiten import WhitenRecording, whiten, compute_whitening_matrix from .rectify import RectifyRecording, rectify from .clip import BlankSaturationRecording, blank_staturation, ClipRecording, clip diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py new file mode 100644 index 0000000000..bc77577ce0 --- /dev/null +++ b/src/spikeinterface/preprocessing/scale.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.core import BaseRecording +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor + + +def scale_to_uV(recording: BasePreprocessor) -> BasePreprocessor: + """ + Scale raw traces to microvolts (µV). + + This preprocessor uses the channel-specific gain and offset information + stored in the recording extractor to convert the raw traces to µV units. + + Parameters + ---------- + recording : BaseRecording + The recording extractor to be scaled. The recording extractor must + have gains and offsets otherwise an error will be raised. + + Raises + ------ + AssertionError + If the recording extractor does not have scaleable traces. + """ + # To avoid a circular import + from spikeinterface.preprocessing import ScaleRecording + + if not recording.has_scaleable_traces(): + error_msg = "Recording must have gains and offsets set to be scaled to µV" + raise RuntimeError(error_msg) + + gain = recording.get_channel_gains() + offset = recording.get_channel_offsets() + + scaled_to_uV_recording = ScaleRecording(recording, gain=gain, offset=offset, dtype="float32") + + # We do this so when get_traces(return_scaled=True) is called, the return is the same. + scaled_to_uV_recording.set_channel_gains(gains=1.0) + scaled_to_uV_recording.set_channel_offsets(offsets=0.0) + + return scaled_to_uV_recording diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py new file mode 100644 index 0000000000..321d7c9df2 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -0,0 +1,70 @@ +import pytest +import numpy as np +from spikeinterface.core.testing_tools import generate_recording +from spikeinterface.preprocessing import scale_to_uV, CenterRecording + + +def test_scale_to_uV(): + # Create a sample recording extractor with fake gains and offsets + num_channels = 4 + sampling_frequency = 30_000.0 + durations = [1.0, 1.0] # seconds + recording = generate_recording( + num_channels=num_channels, + durations=durations, + sampling_frequency=sampling_frequency, + ) + + rng = np.random.default_rng(0) + gains = rng.random(size=(num_channels)).astype(np.float32) + offsets = rng.random(size=(num_channels)).astype(np.float32) + recording.set_channel_gains(gains) + recording.set_channel_offsets(offsets) + + # Apply the preprocessor + scaled_recording = scale_to_uV(recording=recording) + + # Check if the traces are indeed scaled + expected_traces = recording.get_traces(return_scaled=True, segment_index=0) + scaled_traces = scaled_recording.get_traces(segment_index=0) + + np.testing.assert_allclose(scaled_traces, expected_traces) + + # Test for the error when recording doesn't have scaleable traces + recording.set_channel_gains(None) # Remove gains to make traces unscaleable + with pytest.raises(RuntimeError): + scale_to_uV(recording) + + +def test_scaling_in_preprocessing_chain(): + + # Create a sample recording extractor with fake gains and offsets + num_channels = 4 + sampling_frequency = 30_000.0 + durations = [1.0] # seconds + recording = generate_recording( + num_channels=num_channels, + durations=durations, + sampling_frequency=sampling_frequency, + ) + + rng = np.random.default_rng(0) + gains = rng.random(size=(num_channels)).astype(np.float32) + offsets = rng.random(size=(num_channels)).astype(np.float32) + + recording.set_channel_gains(gains) + recording.set_channel_offsets(offsets) + + centered_recording = CenterRecording(scale_to_uV(recording=recording)) + traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) + + # Chain preprocessors + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() + + np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) + + # Test if the scaling is not done twice + traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) + + np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument)