Skip to content

Commit

Permalink
Merge pull request #2431 from cwindolf/normalize_scale_pickle_fix
Browse files Browse the repository at this point in the history
Faster unpickling of ZScoreRecording
  • Loading branch information
alejoe91 authored Jan 24, 2024
2 parents fe0cbf6 + 999dbe8 commit c2f18bd
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions src/spikeinterface/preprocessing/normalize_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,6 @@ def __init__(
if dtype_.kind == "i":
assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a scale"

random_data = get_random_data_chunks(recording, **random_chunk_kwargs)

if gain is not None:
assert offset is not None
gain = np.asarray(gain)
Expand All @@ -285,20 +283,23 @@ def __init__(
if offset.ndim == 1:
offset = offset[None, :]
assert offset.shape[1] == n
elif mode == "median+mad":
medians = np.median(random_data, axis=0)
medians = medians[None, :]
mads = np.median(np.abs(random_data - medians), axis=0) / 0.6744897501960817
mads = mads[None, :]
gain = 1 / mads
offset = -medians / mads
else:
means = np.mean(random_data, axis=0)
means = means[None, :]
stds = np.std(random_data, axis=0)
stds = stds[None, :]
gain = 1.0 / stds
offset = -means / stds
random_data = get_random_data_chunks(recording, **random_chunk_kwargs)

if mode == "median+mad":
medians = np.median(random_data, axis=0)
medians = medians[None, :]
mads = np.median(np.abs(random_data - medians), axis=0) / 0.6744897501960817
mads = mads[None, :]
gain = 1 / mads
offset = -medians / mads
else:
means = np.mean(random_data, axis=0)
means = means[None, :]
stds = np.std(random_data, axis=0)
stds = stds[None, :]
gain = 1.0 / stds
offset = -means / stds

if int_scale is not None:
gain *= int_scale
Expand Down

0 comments on commit c2f18bd

Please sign in to comment.