Merge branch 'main' into searchsorted
yger authored Sep 19, 2023
2 parents 9d07ec2 + 0adee81 commit 3e860d4
Showing 24 changed files with 118 additions and 52 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/comparison/basecomparison.py
@@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self):
indexes = np.arange(scores.shape[1])
order1 = []
for r in range(scores.shape[0]):
possible = indexes[~np.in1d(indexes, order1)]
possible = indexes[~np.isin(indexes, order1)]
if possible.size > 0:
ind = np.argmax(scores.iloc[r, possible].values)
order1.append(possible[ind])
remain = indexes[~np.in1d(indexes, order1)]
remain = indexes[~np.isin(indexes, order1)]
order1.extend(remain)
scores = scores.iloc[:, order1]

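The change repeated across most files in this commit is mechanical: the older np.in1d spelling is swapped for np.isin, which NumPy recommends over np.in1d. A minimal, self-contained sketch of the equivalence, using made-up arrays rather than SpikeInterface objects:

    import numpy as np

    indexes = np.arange(6)
    order1 = [1, 3]

    # Old spelling, as removed throughout this commit:
    #   possible = indexes[~np.in1d(indexes, order1)]
    # New spelling, behaviourally identical here:
    possible = indexes[~np.isin(indexes, order1)]
    print(possible)  # -> [0 2 4 5]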
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/comparisontools.py
@@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun
matched_units2 = match_12[match_12 != -1].values

unmatched_units1 = match_12[match_12 == -1].index
unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)]
unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)]

ordered_units1 = np.hstack([matched_units1, unmatched_units1])
ordered_units2 = np.hstack([matched_units2, unmatched_units2])
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecording.py
@@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None):
def _remove_channels(self, remove_channel_ids):
from .channelslice import ChannelSliceRecording

new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)]
new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)]
sub_recording = ChannelSliceRecording(self, new_channel_ids)
return sub_recording

2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesnippets.py
@@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None):
def _remove_channels(self, remove_channel_ids):
from .channelslice import ChannelSliceSnippets

new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)]
new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)]
sub_recording = ChannelSliceSnippets(self, new_channel_ids)
return sub_recording

2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids):
"""
from spikeinterface import UnitsSelectionSorting

new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)]
new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)]
new_sorting = UnitsSelectionSorting(self, new_unit_ids)
return new_sorting

4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
@@ -166,7 +166,7 @@ def generate_sorting(
)

if empty_units is not None:
keep = ~np.in1d(labels, empty_units)
keep = ~np.isin(labels, empty_units)
times = times[keep]
labels = labels[keep]

@@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
sample_index = spike["sample_index"]
if sample_index not in units_used_for_spike:
units_used_for_spike[sample_index] = np.array([spike["unit_index"]])
units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])]
units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])]

if len(units_not_used) == 0:
continue
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sparsity.py
@@ -34,7 +34,7 @@ def test_ChannelSparsity():

for key, v in sparsity.unit_id_to_channel_ids.items():
assert key in unit_ids
assert np.all(np.in1d(v, channel_ids))
assert np.all(np.isin(v, channel_ids))

for key, v in sparsity.unit_id_to_channel_indices.items():
assert key in unit_ids
44 changes: 40 additions & 4 deletions src/spikeinterface/core/waveform_extractor.py
@@ -4,6 +4,7 @@
import shutil
from typing import Iterable, Literal, Optional
import json
import os

import numpy as np
from copy import deepcopy
@@ -87,6 +88,7 @@ def __init__(
self._template_cache = {}
self._params = {}
self._loaded_extensions = dict()
self._is_read_only = False
self.sparsity = sparsity

self.folder = folder
@@ -103,6 +105,8 @@ def __init__(
if (self.folder / "params.json").is_file():
with open(str(self.folder / "params.json"), "r") as f:
self._params = json.load(f)
if not os.access(self.folder, os.W_OK):
self._is_read_only = True
else:
# this is in case of in-memory
self.format = "memory"
@@ -399,6 +403,9 @@ def return_scaled(self) -> bool:
def dtype(self):
return self._params["dtype"]

def is_read_only(self) -> bool:
return self._is_read_only

def has_recording(self) -> bool:
return self._recording is not None

@@ -516,6 +523,10 @@ def is_extension(self, extension_name) -> bool:
"""
if self.folder is None:
return extension_name in self._loaded_extensions

if extension_name in self._loaded_extensions:
# extension already loaded in memory
return True
else:
if self.format == "binary":
return (self.folder / extension_name).is_dir() and (
@@ -1740,13 +1751,33 @@ def __init__(self, waveform_extractor):
if self.format == "binary":
self.extension_folder = self.folder / self.extension_name
if not self.extension_folder.is_dir():
self.extension_folder.mkdir()
if self.waveform_extractor.is_read_only():
warn(
"WaveformExtractor: cannot save extension in read-only mode. "
"Extension will be saved in memory."
)
self.format = "memory"
self.extension_folder = None
self.folder = None
else:
self.extension_folder.mkdir()

else:
import zarr

zarr_root = zarr.open(self.folder, mode="r+")
mode = "r+" if not self.waveform_extractor.is_read_only() else "r"
zarr_root = zarr.open(self.folder, mode=mode)
if self.extension_name not in zarr_root.keys():
self.extension_group = zarr_root.create_group(self.extension_name)
if self.waveform_extractor.is_read_only():
warn(
"WaveformExtractor: cannot save extension in read-only mode. "
"Extension will be saved in memory."
)
self.format = "memory"
self.extension_folder = None
self.folder = None
else:
self.extension_group = zarr_root.create_group(self.extension_name)
else:
self.extension_group = zarr_root[self.extension_name]
else:
@@ -1863,6 +1894,9 @@ def save(self, **kwargs):
self._save(**kwargs)

def _save(self, **kwargs):
# Only save if not read only
if self.waveform_extractor.is_read_only():
return
if self.format == "binary":
import pandas as pd

@@ -1900,7 +1934,9 @@ def _save(self, **kwargs):
self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a"
store=self.extension_group.store,
group=f"{self.extension_group.name}/{ext_data_name}",
mode="a",
)
self.extension_group[ext_data_name].attrs["dataframe"] = True
else:
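The waveform_extractor.py changes above add read-only awareness: writability of the waveforms folder is probed with os.access, and extensions fall back to an in-memory format instead of attempting to write to disk. A stand-alone sketch of that pattern, with an illustrative helper name and warning text (not the exact SpikeInterface internals):

    import os
    import warnings
    from pathlib import Path

    def resolve_extension_target(waveforms_folder: Path, extension_name: str):
        """Decide where a new extension can live: on disk or in memory only."""
        if not os.access(waveforms_folder, os.W_OK):
            warnings.warn("cannot save extension in read-only mode; keeping it in memory")
            return "memory", None
        extension_folder = waveforms_folder / extension_name
        extension_folder.mkdir(exist_ok=True)
        return "binary", extension_folder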
9 changes: 5 additions & 4 deletions src/spikeinterface/curation/mergeunitssorting.py
@@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting):
----------
parent_sorting: Recording
The sorting object
units_to_merge: list of lists
units_to_merge: list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids: None or list
@@ -24,6 +24,7 @@ class MergeUnitsSorting(BaseSorting):
Default: 'keep'
delta_time_ms: float or None
Number of ms to consider for duplicated spikes. None won't check for duplications
Returns
-------
sorting: Sorting
@@ -33,7 +34,7 @@
def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = parent_sorting

if not isinstance(units_to_merge[0], list):
if not isinstance(units_to_merge[0], (list, tuple)):
# keep backward compatibility : the previous behavior was only one merge
units_to_merge = [units_to_merge]

@@ -59,7 +60,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
else:
# we cannot automatically find new names
new_unit_ids = [f"merge{i}" for i in range(num_merge)]
if np.any(np.in1d(new_unit_ids, keep_unit_ids)):
if np.any(np.isin(new_unit_ids, keep_unit_ids)):
raise ValueError(
"Unable to find 'new_unit_ids' because it is a string and parents "
"already contain merges. Pass a list of 'new_unit_ids' as an argument."
@@ -68,7 +69,7 @@
# dtype int
new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
else:
if np.any(np.in1d(new_unit_ids, keep_unit_ids)):
if np.any(np.isin(new_unit_ids, keep_unit_ids)):
raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones")

assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge"
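The mergeunitssorting.py change relaxes the type check so merge groups can be tuples as well as lists. Assuming an existing sorting object with unit ids 0-4, a call along these lines (illustrative, not taken from the commit) should now be accepted in either form:

    from spikeinterface.curation import MergeUnitsSorting

    # `sorting` is assumed to be an existing BaseSorting with unit ids 0..4.
    merged = MergeUnitsSorting(sorting, units_to_merge=[(0, 1), (2, 3, 4)])
    # equivalent to units_to_merge=[[0, 1], [2, 3, 4]]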
6 changes: 5 additions & 1 deletion src/spikeinterface/exporters/to_phy.py
@@ -178,7 +178,11 @@ def export_to_phy(
templates[unit_ind, :, :][:, : len(chan_inds)] = template
templates_ind[unit_ind, : len(chan_inds)] = chan_inds

template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity")
if waveform_extractor.is_extension("similarity"):
tmc = waveform_extractor.load_extension("similarity")
template_similarity = tmc.get_data()
else:
template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity")

np.save(str(output_folder / "templates.npy"), templates)
np.save(str(output_folder / "template_ind.npy"), templates_ind)
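The to_phy.py change lets export_to_phy reuse a previously computed "similarity" extension instead of recomputing cosine similarity from scratch. A hedged usage sketch (the waveform extractor `we` and the output path are assumed to exist already):

    from spikeinterface.postprocessing import compute_template_similarity
    from spikeinterface.exporters import export_to_phy

    # Computing the extension beforehand means export_to_phy can pick it up
    # via is_extension("similarity") rather than recomputing it.
    compute_template_similarity(we, method="cosine_similarity")
    export_to_phy(we, output_folder="phy_output")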
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/bids.py
@@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids):
contact_ids = channels["contact_id"].values.astype("U")

# extracting information of requested channels
keep = np.in1d(channel_ids, recording_channel_ids)
keep = np.isin(channel_ids, recording_channel_ids)
channel_ids = channel_ids[keep]
contact_ids = contact_ids[keep]

4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -47,9 +47,9 @@ def _set_params(

def _select_extension_data(self, unit_ids):
old_unit_ids = self.waveform_extractor.sorting.unit_ids
unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids))
unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids))

spike_mask = np.in1d(self.spikes["unit_index"], unit_inds)
spike_mask = np.isin(self.spikes["unit_index"], unit_inds)
new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask]
return dict(amplitude_scalings=new_amplitude_scalings)

4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids):
# load filter and save amplitude files
sorting = self.waveform_extractor.sorting
spikes = sorting.to_spike_vector(concatenated=False)
(keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids))
(keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids))

new_extension_data = dict()
for seg_index in range(sorting.get_num_segments()):
amp_data_name = f"amplitude_segment_{seg_index}"
amps = self._extension_data[amp_data_name]
filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices)
filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices)
new_extension_data[amp_data_name] = amps[filtered_idxs]
return new_extension_data

4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/spike_locations.py
@@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth

def _select_extension_data(self, unit_ids):
old_unit_ids = self.waveform_extractor.sorting.unit_ids
unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids))
unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids))

spike_mask = np.in1d(self.spikes["unit_index"], unit_inds)
spike_mask = np.isin(self.spikes["unit_index"], unit_inds)
new_spike_locations = self._extension_data["spike_locations"][spike_mask]
return dict(spike_locations=new_spike_locations)

@@ -2,9 +2,10 @@
import numpy as np
import pandas as pd
import shutil
import platform
from pathlib import Path

from spikeinterface import extract_waveforms, load_extractor, compute_sparsity
from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity
from spikeinterface.extractors import toy_example

if hasattr(pytest, "global_test_folder"):
@@ -76,6 +77,16 @@ def setUp(self):
overwrite=True,
)
self.we2 = we2

# make we read-only
if platform.system() != "Windows":
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
if not we_ro_folder.is_dir():
shutil.copytree(we2.folder, we_ro_folder)
# change permissions (R+X)
we_ro_folder.chmod(0o555)
self.we_ro = load_waveforms(we_ro_folder)

self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30)
we_memory = extract_waveforms(
recording,
@@ -97,6 +108,12 @@ def setUp(self):
folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True
)

def tearDown(self):
# allow pytest to delete RO folder
if platform.system() != "Windows":
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
we_ro_folder.chmod(0o777)

def _test_extension_folder(self, we, in_memory=False):
if self.extension_function_kwargs_list is None:
extension_function_kwargs_list = [dict()]
@@ -177,3 +194,11 @@ def test_extension(self):
assert ext_data_mem.equals(ext_data_zarr)
else:
print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.")

# read-only - Extension is memory only
if platform.system() != "Windows":
_ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False)
assert self.extension_class.extension_name in self.we_ro.get_available_extension_names()
ext_ro = self.we_ro.load_extension(self.extension_class.extension_name)
assert ext_ro.format == "memory"
assert ext_ro.extension_folder is None
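The test additions above exercise the read-only path by copying a waveforms folder, removing write permission, and restoring it on teardown so pytest can clean up. A small stand-alone version of that pattern, with illustrative paths and helper names (skipped on Windows, where chmod-based read-only folders are unreliable):

    import platform
    import shutil
    from pathlib import Path

    def make_readonly_copy(src: Path, dst: Path) -> Path:
        """Copy a folder and drop write permission (POSIX only)."""
        if platform.system() == "Windows":
            raise RuntimeError("read-only folder fixture not supported on Windows")
        if not dst.is_dir():
            shutil.copytree(src, dst)
        dst.chmod(0o555)  # r-x only
        return dst

    def restore_writable(dst: Path) -> None:
        """Give write permission back so the test runner can delete the folder."""
        dst.chmod(0o777)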
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/unit_localization.py
@@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals
def get_grid_convolution_templates_and_weights(
contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50
):
import sklearn.metrics

x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max()
y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max()

@@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights(
template_positions[:, 0] = all_x.flatten()
template_positions[:, 1] = all_y.flatten()

import sklearn

# mask to get nearest template given a channel
dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
nearest_template_mask = dist < radius_um
@@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non

self.bad_channel_ids = bad_channel_ids
self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids)
self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs)
self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs)
self._bad_channel_idxs.setflags(write=False)

if sigma_um is None:
2 changes: 1 addition & 1 deletion src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
# some segments/units might have no spikes
if len(spikes_per_unit) == 0:
continue
spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])]
spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
for synchrony_size in synchrony_sizes:
synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)
