in1d to isin with correct alias (shame on me)
h-mayorquin committed Sep 15, 2023
1 parent e947e09 commit 5e420f3
Showing 20 changed files with 42 additions and 42 deletions.
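Background on the rename: np.isin is the spelling recommended by the NumPy documentation, while np.in1d is kept as a legacy alias (and is slated for deprecation in newer NumPy releases). For the 1-D id and index arrays used throughout this diff the two functions return identical boolean masks, so the swap is behavior-preserving; they differ only for N-D inputs, where np.isin keeps the shape of its first argument and np.in1d flattens it. A minimal sketch with illustrative arrays (not taken from the codebase):

import numpy as np

channel_ids = np.array(["ch0", "ch1", "ch2", "ch3"])
remove_ids = np.array(["ch1", "ch3"])

# Identical masks for 1-D inputs, so the rename does not change behavior here.
assert np.array_equal(np.in1d(channel_ids, remove_ids), np.isin(channel_ids, remove_ids))

# The idiom used in most hunks below: keep every id that is NOT in remove_ids.
kept = channel_ids[~np.isin(channel_ids, remove_ids)]
print(kept)  # ['ch0' 'ch2']

# Only N-D inputs behave differently: isin preserves shape, in1d flattens.
a = np.arange(6).reshape(2, 3)
print(np.isin(a, [1, 4]).shape)  # (2, 3)
print(np.in1d(a, [1, 4]).shape)  # (6,)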
4 changes: 2 additions & 2 deletions src/spikeinterface/comparison/basecomparison.py
@@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self):
indexes = np.arange(scores.shape[1])
order1 = []
for r in range(scores.shape[0]):
- possible = indexes[~np.in1d(indexes, order1)]
+ possible = indexes[~np.isin(indexes, order1)]
if possible.size > 0:
ind = np.argmax(scores.iloc[r, possible].values)
order1.append(possible[ind])
- remain = indexes[~np.in1d(indexes, order1)]
+ remain = indexes[~np.isin(indexes, order1)]
order1.extend(remain)
scores = scores.iloc[:, order1]

2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/comparisontools.py
@@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun
matched_units2 = match_12[match_12 != -1].values

unmatched_units1 = match_12[match_12 == -1].index
- unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)]
+ unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)]

ordered_units1 = np.hstack([matched_units1, unmatched_units1])
ordered_units2 = np.hstack([matched_units2, unmatched_units2])
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecording.py
@@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None):
def _remove_channels(self, remove_channel_ids):
from .channelslice import ChannelSliceRecording

- new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)]
+ new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)]
sub_recording = ChannelSliceRecording(self, new_channel_ids)
return sub_recording

2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesnippets.py
@@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None):
def _remove_channels(self, remove_channel_ids):
from .channelslice import ChannelSliceSnippets

- new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)]
+ new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)]
sub_recording = ChannelSliceSnippets(self, new_channel_ids)
return sub_recording

2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids):
"""
from spikeinterface import UnitsSelectionSorting

- new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)]
+ new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)]
new_sorting = UnitsSelectionSorting(self, new_unit_ids)
return new_sorting

4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
@@ -166,7 +166,7 @@ def generate_sorting(
)

if empty_units is not None:
- keep = ~np.in1d(labels, empty_units)
+ keep = ~np.isin(labels, empty_units)
times = times[keep]
labels = labels[keep]

@@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
sample_index = spike["sample_index"]
if sample_index not in units_used_for_spike:
units_used_for_spike[sample_index] = np.array([spike["unit_index"]])
- units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])]
+ units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])]

if len(units_not_used) == 0:
continue
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sparsity.py
@@ -34,7 +34,7 @@ def test_ChannelSparsity():

for key, v in sparsity.unit_id_to_channel_ids.items():
assert key in unit_ids
- assert np.all(np.in1d(v, channel_ids))
+ assert np.all(np.isin(v, channel_ids))

for key, v in sparsity.unit_id_to_channel_indices.items():
assert key in unit_ids
4 changes: 2 additions & 2 deletions src/spikeinterface/curation/mergeunitssorting.py
@@ -59,7 +59,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
else:
# we cannot automatically find new names
new_unit_ids = [f"merge{i}" for i in range(num_merge)]
- if np.any(np.in1d(new_unit_ids, keep_unit_ids)):
+ if np.any(np.isin(new_unit_ids, keep_unit_ids)):
raise ValueError(
"Unable to find 'new_unit_ids' because it is a string and parents "
"already contain merges. Pass a list of 'new_unit_ids' as an argument."
@@ -68,7 +68,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
# dtype int
new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
else:
- if np.any(np.in1d(new_unit_ids, keep_unit_ids)):
+ if np.any(np.isin(new_unit_ids, keep_unit_ids)):
raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones")

assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge"
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/bids.py
@@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids):
contact_ids = channels["contact_id"].values.astype("U")

# extracting information of requested channels
- keep = np.in1d(channel_ids, recording_channel_ids)
+ keep = np.isin(channel_ids, recording_channel_ids)
channel_ids = channel_ids[keep]
contact_ids = contact_ids[keep]

4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -47,9 +47,9 @@ def _set_params(

def _select_extension_data(self, unit_ids):
old_unit_ids = self.waveform_extractor.sorting.unit_ids
- unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids))
+ unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids))

- spike_mask = np.in1d(self.spikes["unit_index"], unit_inds)
+ spike_mask = np.isin(self.spikes["unit_index"], unit_inds)
new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask]
return dict(amplitude_scalings=new_amplitude_scalings)

4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids):
# load filter and save amplitude files
sorting = self.waveform_extractor.sorting
spikes = sorting.to_spike_vector(concatenated=False)
- (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids))
+ (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids))

new_extension_data = dict()
for seg_index in range(sorting.get_num_segments()):
amp_data_name = f"amplitude_segment_{seg_index}"
amps = self._extension_data[amp_data_name]
- filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices)
+ filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices)
new_extension_data[amp_data_name] = amps[filtered_idxs]
return new_extension_data

4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/spike_locations.py
@@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth

def _select_extension_data(self, unit_ids):
old_unit_ids = self.waveform_extractor.sorting.unit_ids
- unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids))
+ unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids))

- spike_mask = np.in1d(self.spikes["unit_index"], unit_inds)
+ spike_mask = np.isin(self.spikes["unit_index"], unit_inds)
new_spike_locations = self._extension_data["spike_locations"][spike_mask]
return dict(spike_locations=new_spike_locations)

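The three postprocessing extensions above (amplitude_scalings.py, spike_amplitudes.py, spike_locations.py) share the same selection pattern: map the requested unit ids to integer unit indices, then mask the cached spike data on its unit_index field. A self-contained sketch of that pattern, with made-up arrays standing in for sorting.unit_ids and the extension's spike vector:

import numpy as np

old_unit_ids = np.array(["u0", "u1", "u2", "u3"])
spikes = np.array(
    [(10, 0), (25, 2), (40, 1), (55, 2), (70, 3)],
    dtype=[("sample_index", "int64"), ("unit_index", "int64")],
)
amplitude_scalings = np.array([0.9, 1.1, 1.0, 0.8, 1.2])

unit_ids = ["u2", "u0"]  # units to keep; order does not matter for masking

# ids -> integer positions in old_unit_ids
unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids))  # array([0, 2])
# keep every spike whose unit_index is one of those positions
spike_mask = np.isin(spikes["unit_index"], unit_inds)
print(amplitude_scalings[spike_mask])  # [0.9 1.1 0.8]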
@@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non

self.bad_channel_ids = bad_channel_ids
self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids)
- self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs)
+ self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs)
self._bad_channel_idxs.setflags(write=False)

if sigma_um is None:
2 changes: 1 addition & 1 deletion src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
# some segments/units might have no spikes
if len(spikes_per_unit) == 0:
continue
- spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])]
+ spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
for synchrony_size in synchrony_sizes:
synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)

10 changes: 5 additions & 5 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -152,8 +152,8 @@ def calculate_pc_metrics(
neighbor_unit_ids = unit_ids
neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids)

- labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)]
- pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
+ labels = all_labels[np.isin(all_labels, neighbor_unit_ids)]
+ pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
pcs_flat = pcs.reshape(pcs.shape[0], -1)

func_args = (
@@ -506,7 +506,7 @@ def nearest_neighbors_isolation(
other_units_ids = [
unit_id
for unit_id in other_units_ids
- if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit))
+ if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit))
>= (n_channels_target_unit * min_spatial_overlap)
]

@@ -536,10 +536,10 @@ def nearest_neighbors_isolation(
if waveform_extractor.is_sparse():
# in this case, waveforms are sparse so we need to do some smart indexing
waveforms_target_unit_sampled = waveforms_target_unit_sampled[
- :, :, np.in1d(closest_chans_target_unit, common_channel_idxs)
+ :, :, np.isin(closest_chans_target_unit, common_channel_idxs)
]
waveforms_other_unit_sampled = waveforms_other_unit_sampled[
- :, :, np.in1d(closest_chans_other_unit, common_channel_idxs)
+ :, :, np.isin(closest_chans_other_unit, common_channel_idxs)
]
else:
waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs]
@@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine
seg_num = 0 # TODO: make compatible with multiple segments
idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label)
idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"]
- intersection = np.where(np.in1d(idx_2, idx_1))[0]
+ intersection = np.where(np.isin(idx_2, idx_1))[0]
intersection = np.random.permutation(intersection)[:nb_spikes]
if len(intersection) == 0:
print(f"No {label}s found for unit {unit_id}")
@@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos

for label in ["TP", "FN"]:
idx_1 = np.where(comp.get_labels1(unit_id) == label)[0]
- intersection = np.where(np.in1d(idx_2, idx_1))[0]
+ intersection = np.where(np.isin(idx_2, idx_1))[0]
intersection = np.random.permutation(intersection)[:nb_spikes]
wfs_sliced = wfs[intersection, :, :]

@@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2):
matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000))
self.good_matches = matches["index1"]

- garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches)
+ garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches)
garbage_channels = self.peaks["channel_index"][garbage_matches]
garbage_peaks = times2[garbage_matches]
nb_garbage = len(garbage_peaks)
@@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0),

idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"]
all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id)
- mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx])
+ mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx])
colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask])
ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5)
x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean())
@@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0),

idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"]
all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id)
- mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx])
+ mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx])
colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask])
ax.scatter(
self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5
@@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0),

idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"]
all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id)
- mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx])
+ mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx])
colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask])
ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5)
x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean())
@@ -30,7 +30,7 @@ def _split_waveforms(
local_labels_with_noise = clustering[0]
cluster_probability = clustering[2]
(persistent_clusters,) = np.nonzero(cluster_probability > probability_thr)
- local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1
+ local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1

# remove super small cluster
labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True)
@@ -43,7 +43,7 @@ def _split_waveforms(
to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio]
# ~ print('to_remove', to_remove, count / valid_size)
if to_remove.size > 0:
- local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1
+ local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1

local_labels_with_noise[valid_size:] = -2

@@ -123,7 +123,7 @@ def _split_waveforms_nested(
active_labels_with_noise = clustering[0]
cluster_probability = clustering[2]
(persistent_clusters,) = np.nonzero(clustering[2] > probability_thr)
- active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1
+ active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1

active_labels = active_labels_with_noise[active_ind < valid_size]
active_labels_set = np.unique(active_labels)
@@ -381,9 +381,9 @@ def auto_clean_clustering(
continue

wfs0 = wfs_arrays[label0]
- wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)]
+ wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)]
wfs1 = wfs_arrays[label1]
- wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)]
+ wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)]

# TODO : remove
assert wfs0.shape[2] == wfs1.shape[2]
@@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d):
for chan_ind in prev_local_chan_inds:
if total_count[chan_ind] == 0:
continue
- # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0))
+ # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0))
(inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0))
if inds.size <= d["min_spike_on_channel"]:
chan_amps[chan_ind] = 0.0
@@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d):

(wf_chans,) = np.nonzero(sparsity_mask[chan_ind])
# TODO: only for debug, remove later
- assert np.all(np.in1d(local_chan_inds, wf_chans))
+ assert np.all(np.isin(local_chan_inds, wf_chans))

# none label spikes
wfs_chan = wfs_chan[inds, :, :]
# only some channels
- wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)]
+ wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)]
wfs.append(wfs_chan)

# put noise to enhance clusters
@@ -517,15 +517,15 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels,
(wf_chans,) = np.nonzero(sparsity_mask[chan_ind])
# print('wf_chans', wf_chans)
# TODO: only for debug, remove later
- assert np.all(np.in1d(wanted_chans, wf_chans))
+ assert np.all(np.isin(wanted_chans, wf_chans))
wfs_chan = wfs_arrays[chan_ind]

# TODO: only for debug, remove later
assert wfs_chan.shape[0] == sel.size

wfs_chan = wfs_chan[inds, :, :]
# only some channels
- wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)]
+ wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)]
wfs.append(wfs_chan)

wfs = np.concatenate(wfs, axis=0)
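In the sorting-components hunks above, np.isin builds a boolean mask along the channel axis of a 3-D waveform buffer (spikes x samples x channels), keeping only the channels shared between the sparse buffer and the wanted channel set; the surrounding asserts guarantee every wanted channel is actually present in the buffer. A sketch of that indexing with hypothetical shapes:

import numpy as np

rng = np.random.default_rng(0)
wfs_chan = rng.standard_normal((50, 60, 5))  # 50 spikes, 60 samples, 5 stored channels
wf_chans = np.array([3, 7, 9, 12, 20])       # channel indices present in the buffer
wanted_chans = np.array([7, 12, 20])         # channels to keep

assert np.all(np.isin(wanted_chans, wf_chans))  # same guard as in the diff

# Mask along the last (channel) axis; order follows wf_chans, not wanted_chans.
mask = np.isin(wf_chans, wanted_chans)  # [False, True, False, True, True]
wfs_sub = wfs_chan[:, :, mask]
print(wfs_sub.shape)  # (50, 60, 3)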
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_waveforms_density_map.py
@@ -103,7 +103,7 @@ def __init__(
if same_axis and not np.array_equal(chan_inds, shared_chan_inds):
# add more channels if necessary
wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float)
- mask = np.in1d(shared_chan_inds, chan_inds)
+ mask = np.isin(shared_chan_inds, chan_inds)
wfs_[:, :, mask] = wfs
wfs_[:, :, ~mask] = np.nan
wfs = wfs_
