diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..5af20d79b5 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~np.isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~np.isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..20ee7910b4 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..08f187895b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..f35bc2b266 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index eb141abde4..e6d08d38f7 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..07837bcef7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -166,7 +166,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~np.isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a6b94c9b84..75182bf532 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(np.isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 877c9fb00c..6881ab3ec5 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -4,6 +4,7 @@ import shutil from typing import Iterable, Literal, Optional import json +import os import numpy as np from copy import deepcopy @@ -87,6 +88,7 @@ def __init__( self._template_cache = {} self._params = {} self._loaded_extensions = dict() + self._is_read_only = False self.sparsity = sparsity self.folder = folder @@ -103,6 +105,8 @@ def __init__( if (self.folder / "params.json").is_file(): with open(str(self.folder / "params.json"), "r") as f: self._params = json.load(f) + if not os.access(self.folder, os.W_OK): + self._is_read_only = True else: # this is in case of in-memory self.format = "memory" @@ -399,6 +403,9 @@ def return_scaled(self) -> bool: def dtype(self): return self._params["dtype"] + def is_read_only(self) -> bool: + return self._is_read_only + def has_recording(self) -> bool: return self._recording is not None @@ -516,6 +523,10 @@ def is_extension(self, extension_name) -> bool: """ if self.folder is None: return extension_name in self._loaded_extensions + + if extension_name in self._loaded_extensions: + # extension already loaded in memory + return True else: if self.format == "binary": return (self.folder / extension_name).is_dir() and ( @@ -1740,13 +1751,33 @@ def __init__(self, waveform_extractor): if self.format == "binary": self.extension_folder = self.folder / self.extension_name if not self.extension_folder.is_dir(): - self.extension_folder.mkdir() + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_folder.mkdir() + else: import zarr - zarr_root = zarr.open(self.folder, mode="r+") + mode = "r+" if not self.waveform_extractor.is_read_only() else "r" + zarr_root = zarr.open(self.folder, mode=mode) if self.extension_name not in zarr_root.keys(): - self.extension_group = zarr_root.create_group(self.extension_name) + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_group = zarr_root.create_group(self.extension_name) else: self.extension_group = zarr_root[self.extension_name] else: @@ -1863,6 +1894,9 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): + # Only save if not read only + if self.waveform_extractor.is_read_only(): + return if self.format == "binary": import pandas as pd @@ -1900,7 +1934,9 @@ def _save(self, **kwargs): self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif isinstance(ext_data, pd.DataFrame): ext_data.to_xarray().to_zarr( - store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a" + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", ) self.extension_group[ext_data_name].attrs["dataframe"] = True else: diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..5295cc76d8 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): ---------- parent_sorting: Recording The sorting object - units_to_merge: list of lists + units_to_merge: list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids: None or list @@ -24,6 +24,7 @@ class MergeUnitsSorting(BaseSorting): Default: 'keep' delta_time_ms: float or None Number of ms to consider for duplicated spikes. None won't check for duplications + Returns ------- sorting: Sorting @@ -33,7 +34,7 @@ class MergeUnitsSorting(BaseSorting): def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): self._parent_sorting = parent_sorting - if not isinstance(units_to_merge[0], list): + if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge units_to_merge = [units_to_merge] @@ -59,7 +60,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +69,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 5615402fdb..c92861a8bf 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -178,7 +178,11 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if waveform_extractor.is_extension("similarity"): + tmc = waveform_extractor.load_extension("similarity") + template_similarity = tmc.get_data() + else: + template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") np.save(str(output_folder / "templates.npy"), templates) np.save(str(output_folder / "template_ind.npy"), templates_ind) diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..8b70722652 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = np.isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index d4446e2289..22b40a51c5 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index fd6078b9b0..38cb714d59 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..4cbe4d665e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..f7272ddefe 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,9 +2,10 @@ import numpy as np import pandas as pd import shutil +import platform from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, compute_sparsity +from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity from spikeinterface.extractors import toy_example if hasattr(pytest, "global_test_folder"): @@ -76,6 +77,16 @@ def setUp(self): overwrite=True, ) self.we2 = we2 + + # make we read-only + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) + # change permissions (R+X) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) + self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( recording, @@ -97,6 +108,12 @@ def setUp(self): folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True ) + def tearDown(self): + # allow pytest to delete RO folder + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) + def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: extension_function_kwargs_list = [dict()] @@ -177,3 +194,11 @@ def test_extension(self): assert ext_data_mem.equals(ext_data_zarr) else: print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") + + # read-only - Extension is memory only + if platform.system() != "Windows": + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 740fdd234b..d2739f69dd 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): + import sklearn.metrics + x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights( template_positions[:, 0] = all_x.flatten() template_positions[:, 1] = all_y.flatten() - import sklearn - # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) nearest_template_mask = dist < radius_um diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..95ecd0fe52 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 01701e4f65..8dd5f857f6 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 59000211d4..ed06f7d738 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, np.isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, np.isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..772c99bc0a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..73497a59fd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..23fdbf1979 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..08ce9f6791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(np.isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(np.isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index e8a6868e92..b3391c0712 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = np.in1d(shared_chan_inds, chan_inds) + mask = np.isin(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_