From ca5d3302ca4411903184089528baeabe498342e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:59:41 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/to_ibl.py | 36 ++------ src/spikeinterface/exporters/to_ibl_utils.py | 92 ++++++++++---------- 2 files changed, 57 insertions(+), 71 deletions(-) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index d6dbea3cc3..3515bfe0cf 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -148,30 +148,22 @@ def export_to_ibl( print("Running IBL-specific steps...") # Now we need to add the extra IBL specific files - (channel_inds,) = np.isin( - analyzer.recording.channel_ids, analyzer.channel_ids - ).nonzero() + (channel_inds,) = np.isin(analyzer.recording.channel_ids, analyzer.channel_ids).nonzero() ### Run spectral density and rms ### fs_ap = analyzer.recording.sampling_frequency rms_win_length_samples_ap = 2 ** np.ceil(np.log2(fs_ap * rms_win_length_sec)) - total_samples_ap = int( - np.min([fs_ap * total_secs, analyzer.recording.get_num_samples()]) - ) + total_samples_ap = int(np.min([fs_ap * total_secs, analyzer.recording.get_num_samples()])) # the window generator will generates window indices - wingen = WindowGenerator( - ns=total_samples_ap, nswin=rms_win_length_samples_ap, overlap=0 - ) + wingen = WindowGenerator(ns=total_samples_ap, nswin=rms_win_length_samples_ap, overlap=0) win = { "TRMS": np.zeros((wingen.nwin, analyzer.recording.get_num_channels())), "nsamples": np.zeros((wingen.nwin,)), "fscale": fscale(welch_win_length_samples, 1 / fs_ap, one_sided=True), "tscale": wingen.tscale(fs=fs_ap), } - win["spectral_density"] = np.zeros( - (len(win["fscale"]), analyzer.recording.get_num_channels()) - ) + win["spectral_density"] = np.zeros((len(win["fscale"]), analyzer.recording.get_num_channels())) # @Josh: this could be dramatically sped up if we employ SpikeInterface parallelization with tqdm(total=wingen.nwin) as pbar: @@ -213,17 +205,13 @@ def export_to_ibl( "rms": win["TRMS"].astype(np.single), "timestamps": win["tscale"].astype(np.single), } - save_object_npy( - output_folder, object=alf_object_time, dico=tdict, namespace="iblqc" - ) + save_object_npy(output_folder, object=alf_object_time, dico=tdict, namespace="iblqc") fdict = { "power": win["spectral_density"].astype(np.single), "freqs": win["fscale"].astype(np.single), } - save_object_npy( - output_folder, object=alf_object_freq, dico=fdict, namespace="iblqc" - ) + save_object_npy(output_folder, object=alf_object_freq, dico=fdict, namespace="iblqc") ### Save spike info ### @@ -236,9 +224,7 @@ def export_to_ibl( # convert times and squeeze times = np.load(output_folder / "spike_times.npy") - np.save( - output_folder / "spike_times.npy", np.squeeze(times / 30000.0).astype("float64") - ) + np.save(output_folder / "spike_times.npy", np.squeeze(times / 30000.0).astype("float64")) # convert amplitudes and squeeze amps = np.load(output_folder / "amplitudes.npy") @@ -246,9 +232,7 @@ def export_to_ibl( # save depths and channel inds np.save(output_folder / "spike_depths.npy", spike_depths) - np.save( - output_folder / "channel_inds.npy", np.arange(len(channel_inds), dtype="int") - ) + np.save(output_folder / "channel_inds.npy", np.arange(len(channel_inds), dtype="int")) # # save templates cluster_channels = [] @@ -261,9 +245,7 @@ def export_to_ibl( waveform = templates[unit_idx, :, :] extremum_channel_index = extremum_channel_indices[unit_id] peak_waveform = waveform[:, extremum_channel_index] - peakToTrough = ( - np.argmax(peak_waveform) - np.argmin(peak_waveform) - ) / analyzer.sampling_frequency + peakToTrough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / analyzer.sampling_frequency # cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? fails for odd nums of units cluster_channels.append( extremum_channel_index diff --git a/src/spikeinterface/exporters/to_ibl_utils.py b/src/spikeinterface/exporters/to_ibl_utils.py index e5db773332..35e694f5eb 100644 --- a/src/spikeinterface/exporters/to_ibl_utils.py +++ b/src/spikeinterface/exporters/to_ibl_utils.py @@ -4,6 +4,7 @@ Copied from https://github.com/int-brain-lab/ibl-neuropixel/ on 2/1/2024 """ + from math import pi import numpy as np @@ -12,7 +13,6 @@ import re - def _dromedary(string) -> str: """ Convert a string to camel case. Acronyms/initialisms are preserved. @@ -38,17 +38,19 @@ def _dromedary(string) -> str: -------- readableALF """ + def _capitalize(x): return x if x.isupper() else x.capitalize() + if not string: # short circuit on None and '' return string - first, *other = re.split(r'[_\s]', string) + first, *other = re.split(r"[_\s]", string) if len(other) == 0: # Already camel/Pascal case, ensure first letter lower case return first[0].lower() + first[1:] # Convert to camel case, preserving all-uppercase elements first = first if first.isupper() else first.casefold() - return ''.join([first, *map(_capitalize, other)]) + return "".join([first, *map(_capitalize, other)]) def to_alf(object, attribute, extension, namespace=None, timescale=None, extra=None): @@ -93,22 +95,21 @@ def to_alf(object, attribute, extension, namespace=None, timescale=None, extra=N """ # Validate inputs if not extension: - raise TypeError('An extension must be provided') - elif extension.startswith('.'): + raise TypeError("An extension must be provided") + elif extension.startswith("."): extension = extension[1:] - if any(pt is not None and '.' in pt for pt in - (object, attribute, namespace, extension, timescale)): - raise ValueError('ALF parts must not contain a period (`.`)') - if '_' in (namespace or ''): - raise ValueError('Namespace must not contain extra underscores') - if object[0] == '_': - raise ValueError('Objects must not contain underscores; use namespace arg instead') + if any(pt is not None and "." in pt for pt in (object, attribute, namespace, extension, timescale)): + raise ValueError("ALF parts must not contain a period (`.`)") + if "_" in (namespace or ""): + raise ValueError("Namespace must not contain extra underscores") + if object[0] == "_": + raise ValueError("Objects must not contain underscores; use namespace arg instead") # Ensure parts are camel case (converts whitespace and snake case) if timescale: timescale = filter(None, [timescale] if isinstance(timescale, str) else timescale) - timescale = '_'.join(map(_dromedary, timescale)) + timescale = "_".join(map(_dromedary, timescale)) # Convert attribute to camel case, leaving '_times', etc. in tact - times_re = re.search('_(times|timestamps|intervals)$', attribute) + times_re = re.search("_(times|timestamps|intervals)$", attribute) idx = times_re.start() if times_re else len(attribute) attribute = _dromedary(attribute[:idx]) + attribute[idx:] object = _dromedary(object) @@ -117,14 +118,16 @@ def to_alf(object, attribute, extension, namespace=None, timescale=None, extra=N if not extra: extra = () elif isinstance(extra, str): - extra = extra.split('.') + extra = extra.split(".") # Construct ALF file - parts = (('_%s_' % namespace if namespace else '') + object, - attribute + ('_%s' % timescale if timescale else ''), - *extra, - extension) - return '.'.join(parts) + parts = ( + ("_%s_" % namespace if namespace else "") + object, + attribute + ("_%s" % timescale if timescale else ""), + *extra, + extension, + ) + return ".".join(parts) def save_object_npy(alfpath, dico, object, parts=None, namespace=None, timescale=None) -> list: @@ -165,12 +168,13 @@ def save_object_npy(alfpath, dico, object, parts=None, namespace=None, timescale alfpath = Path(alfpath) status = check_dimensions(dico) if status != 0: - raise ValueError('Dimensions are not consistent to save all arrays in ALF format: ' + - str([(k, v.shape) for k, v in dico.items()])) + raise ValueError( + "Dimensions are not consistent to save all arrays in ALF format: " + + str([(k, v.shape) for k, v in dico.items()]) + ) out_files = [] for k, v in dico.items(): - out_file = alfpath / to_alf(object, k, 'npy', - extra=parts, namespace=namespace, timescale=timescale) + out_file = alfpath / to_alf(object, k, "npy", extra=parts, namespace=namespace, timescale=timescale) np.save(out_file, v) out_files.append(out_file) return out_files @@ -196,12 +200,10 @@ def check_dimensions(dico): """ # supported = (np.ndarray, pd.DataFrame) # idt any dataframes in this specific use case for SI supported = (np.ndarray,) # Data types that have a shape attribute - shapes = [dico[lab].shape for lab in dico - if isinstance(dico[lab], supported) and not lab.startswith('timestamps')] + shapes = [dico[lab].shape for lab in dico if isinstance(dico[lab], supported) and not lab.startswith("timestamps")] first_shapes = [sh[0] for sh in shapes] # Continuous timeseries are permitted to be a (2, 2) - timeseries = [k for k, v in dico.items() - if k.startswith('timestamps') and isinstance(v, np.ndarray)] + timeseries = [k for k, v in dico.items() if k.startswith("timestamps") and isinstance(v, np.ndarray)] if any(timeseries): for key in timeseries: if dico[key].ndim == 1 or (dico[key].ndim == 2 and dico[key].shape[1] == 1): @@ -222,7 +224,7 @@ def rms(x, axis=-1): :param axis: (optional, -1) :return: numpy array """ - return np.sqrt(np.mean(x ** 2, axis=axis)) + return np.sqrt(np.mean(x**2, axis=axis)) def _fcn_extrap(x, f, bounds): @@ -255,6 +257,7 @@ def fcn_cosine(bounds, gpu=False): def _cos(x): return (1 - gp.cos((x - bounds[0]) / (bounds[1] - bounds[0]) * gp.pi)) / 2 + func = lambda x: _fcn_extrap(x, _cos, bounds) # noqa return func @@ -285,7 +288,7 @@ def bp(ts, si, b, axis=None): :param axis: axis along which to perform reduction (last axis by default) :return: filtered time serie """ - return _freq_filter(ts, si, b, axis=axis, typ='bp') + return _freq_filter(ts, si, b, axis=axis, typ="bp") def lp(ts, si, b, axis=None): @@ -298,7 +301,7 @@ def lp(ts, si, b, axis=None): :param axis: axis along which to perform reduction (last axis by default) :return: filtered time serie """ - return _freq_filter(ts, si, b, axis=axis, typ='lp') + return _freq_filter(ts, si, b, axis=axis, typ="lp") def hp(ts, si, b, axis=None): @@ -311,19 +314,19 @@ def hp(ts, si, b, axis=None): :param axis: axis along which to perform reduction (last axis by default) :return: filtered time serie """ - return _freq_filter(ts, si, b, axis=axis, typ='hp') + return _freq_filter(ts, si, b, axis=axis, typ="hp") -def _freq_filter(ts, si, b, axis=None, typ='lp'): +def _freq_filter(ts, si, b, axis=None, typ="lp"): """ - Wrapper for hp/lp/bp filters + Wrapper for hp/lp/bp filters """ if axis is None: axis = ts.ndim - 1 ns = ts.shape[axis] f = fscale(ns, si=si, one_sided=True) - if typ == 'bp': - filc = _freq_vector(f, b[0:2], typ='hp') * _freq_vector(f, b[2:4], typ='lp') + if typ == "bp": + filc = _freq_vector(f, b[0:2], typ="hp") * _freq_vector(f, b[2:4], typ="lp") else: filc = _freq_vector(f, b, typ=typ) if axis < (ts.ndim - 1): @@ -331,21 +334,21 @@ def _freq_filter(ts, si, b, axis=None, typ='lp'): return np.real(np.fft.ifft(np.fft.fft(ts, axis=axis) * fexpand(filc, ns, axis=0), axis=axis)) -def _freq_vector(f, b, typ='lp'): +def _freq_vector(f, b, typ="lp"): """ - Returns a frequency modulated vector for filtering + Returns a frequency modulated vector for filtering - :param f: frequency vector, uniform and monotonic - :param b: 2 bounds array - :return: amplitude modulated frequency vector + :param f: frequency vector, uniform and monotonic + :param b: 2 bounds array + :return: amplitude modulated frequency vector """ filc = fcn_cosine(b)(f) - if typ.lower() in ['hp', 'highpass']: + if typ.lower() in ["hp", "highpass"]: return filc - elif typ.lower() in ['lp', 'lowpass']: + elif typ.lower() in ["lp", "lowpass"]: return 1 - filc - + def fexpand(x, ns=1, axis=None): """ Reconstructs full spectrum from positive frequencies @@ -373,6 +376,7 @@ class WindowGenerator(object): Example of implementations in test_dsp.py. """ + def __init__(self, ns, nswin, overlap): """ :param ns: number of sample of the signal along the direction to be windowed