diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 90b39aee8a..7d43982853 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -293,7 +293,7 @@ def __init__( means = means[None, :] stds = np.std(random_data, axis=0) stds = stds[None, :] - gain = 1 / stds + gain = 1.0 / stds offset = -means / stds if int_scale is not None: diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index b62a73a8cb..764acc9852 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -78,13 +78,18 @@ def test_zscore(): assert np.all(np.abs(np.mean(tr, axis=0)) < 0.01) assert np.all(np.abs(np.std(tr, axis=0) - 1) < 0.01) + +def test_zscore_int(): + seed = 1 + rec = generate_recording(seed=seed, mode="legacy") rec_int = scale(rec, dtype="int16", gain=100) with pytest.raises(AssertionError): - rec4 = zscore(rec_int, dtype=None) - rec4 = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) - tr = rec4.get_traces(segment_index=0) - trace_mean = np.mean(tr, axis=0) - trace_std = np.std(tr, axis=0) + zscore(rec_int, dtype=None) + + zscore_recording = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) + traces = zscore_recording.get_traces(segment_index=0) + trace_mean = np.mean(traces, axis=0) + trace_std = np.std(traces, axis=0) assert np.all(np.abs(trace_mean) < 1) assert np.all(np.abs(trace_std - 256) < 1)