From 9b2875fd06782bd6da722fed1f524dfda4203f0a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 21 Nov 2024 11:53:01 +0100
Subject: [PATCH 01/15] Add stream_mode as extra_requirements for NWB when
 streaming

---
 src/spikeinterface/extractors/nwbextractors.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index d797e64910..171992f6b1 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -599,6 +599,8 @@ def __init__(
         else:
             gains, offsets, locations, groups = self._fetch_main_properties_backend()
             self.extra_requirements.append("h5py")
+        if stream_mode is not None:
+            self.extra_requirements.append(stream_mode)
         self.set_channel_gains(gains)
         self.set_channel_offsets(offsets)
         if locations is not None:
@@ -1100,6 +1102,8 @@ def __init__(
         for property_name, property_values in properties.items():
            values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values]
            self.set_property(property_name, values)
+        if stream_mode is not None:
+            self.extra_requirements.append(stream_mode)
 
         if stream_mode is None and file_path is not None:
             file_path = str(Path(file_path).resolve())
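For context, a minimal sketch of what this change affects (hedged: the URL is a placeholder and the backend choice is illustrative). When an NWB file is read with a streaming backend, that backend is now recorded as an extra requirement next to "h5py", so a serialized extractor knows which package it needs:

    from spikeinterface.extractors import NwbRecordingExtractor

    # placeholder URL; any fsspec-readable NWB asset would do
    recording = NwbRecordingExtractor(
        file_path="https://example.org/session.nwb",
        stream_mode="fsspec",
    )
    # after this patch the streaming backend appears here as well
    print(recording.extra_requirements)  # e.g. ["h5py", "fsspec"]
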
From b0e7b1c60086fd818759b8df0f617d5441f4ff40 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 29 Nov 2024 06:04:48 +0100
Subject: [PATCH 02/15] Fix kwargs in silence periods

---
 src/spikeinterface/preprocessing/silence_periods.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 85169011d8..a188f5d8db 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -97,7 +97,12 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
             rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
             self.add_recording_segment(rec_segment)
 
-        self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, noise_generator=noise_generator)
+        self._kwargs = dict(recording=recording,
+                            list_periods=list_periods,
+                            mode=mode,
+                            noise_levels=noise_levels,
+                            seed=seed,
+                            random_chunk_kwargs=random_chunk_kwargs)
 
 
 class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):

From 60d7ad53b59fac5b47bd976905d64efc62b3daeb Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 29 Nov 2024 06:10:02 +0100
Subject: [PATCH 03/15] Fix

---
 src/spikeinterface/preprocessing/silence_periods.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index a188f5d8db..5e410d51d5 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -100,7 +100,6 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
         self._kwargs = dict(recording=recording,
                             list_periods=list_periods,
                             mode=mode,
-                            noise_levels=noise_levels,
                             seed=seed,
                             random_chunk_kwargs=random_chunk_kwargs)
 

From 01ae85cbb3ebcfd6f9e20eb191d92a278bbcb5e4 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 29 Nov 2024 06:14:16 +0100
Subject: [PATCH 04/15] WIP

---
 src/spikeinterface/preprocessing/silence_periods.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 5e410d51d5..7c518d02a0 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -100,8 +100,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
         self._kwargs = dict(recording=recording,
                             list_periods=list_periods,
                             mode=mode,
-                            seed=seed,
-                            random_chunk_kwargs=random_chunk_kwargs)
+                            seed=seed)
+        self._kwargs.update(random_chunk_kwargs)
 
 
 class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):

From db2b4d5130a095500637c06f9d74aa2f03d41b73 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 29 Nov 2024 05:15:54 +0000
Subject: [PATCH 05/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/preprocessing/silence_periods.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 7c518d02a0..00d9a1a407 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -97,10 +97,7 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
             rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
             self.add_recording_segment(rec_segment)
 
-        self._kwargs = dict(recording=recording,
-                            list_periods=list_periods,
-                            mode=mode,
-                            seed=seed)
+        self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed)
         self._kwargs.update(random_chunk_kwargs)
 
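The point of patches 02-05 is that `_kwargs` must mirror the `__init__` signature, because SpikeInterface rebuilds preprocessors from these kwargs on `to_dict()`/load round-trips. A hedged sketch of the round-trip that the stale `noise_generator` key used to break (the period values are illustrative):

    from spikeinterface.core import generate_recording, load_extractor
    from spikeinterface.preprocessing import silence_periods

    recording = generate_recording(durations=[10.0], num_channels=4)
    # one list of (start_frame, end_frame) periods per segment (assumed format)
    periods = [[(30_000, 45_000)]]
    silenced = silence_periods(recording, list_periods=periods, mode="zeros", seed=0)

    # the round-trip only works when _kwargs matches what __init__ accepts
    silenced_again = load_extractor(silenced.to_dict())
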
From 5af4f858268c18421c8d1e7a3cbaca3c9957491e Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Mon, 2 Dec 2024 09:19:15 +0000
Subject: [PATCH 06/15] Hard code synchrony_size for users, but leave flexible
 code underneath

---
 doc/get_started/quickstart.rst                |  2 +-
 doc/modules/qualitymetrics/synchrony.rst      |  4 ++--
 .../qualitymetrics/misc_metrics.py            | 27 ++++++-------
 .../tests/test_metrics_functions.py           | 39 ++++++++-----------
 4 files changed, 30 insertions(+), 42 deletions(-)

diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst
index 3d45606a78..1349802ce5 100644
--- a/doc/get_started/quickstart.rst
+++ b/doc/get_started/quickstart.rst
@@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
                    'min_spikes': 0,
                    'window_size_s': 1},
      'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
-     'synchrony': {'synchrony_sizes': (2, 4, 8)}}
+     'synchrony': {}}
 
 
 Since the recording is very short, let’s change some parameters to

diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst
index d244fd0c0f..696dacbd3c 100644
--- a/doc/modules/qualitymetrics/synchrony.rst
+++ b/doc/modules/qualitymetrics/synchrony.rst
@@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u
 Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
 within and across spike trains.
 
-Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
+Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.
 
 
 
@@ -29,7 +29,7 @@ Example code
     import spikeinterface.qualitymetrics as sqm
 
     # Combine a sorting and recording into a sorting_analyzer
-    synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
+    synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
 
     # synchrony is a tuple of dicts with the synchrony metrics for each unit

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index 8dfd41cf88..b0e0a0ad19 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -520,7 +520,7 @@ def compute_sliding_rp_violations(
     )
 
 
-def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
+def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])):
     """
     Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.
 
@@ -528,10 +528,10 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
     ----------
     spikes : np.array
         Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
-    synchrony_sizes : numpy array
-        The synchrony sizes to compute. Should be pre-sorted.
     all_unit_ids : list or None, default: None
         List of unit ids to compute the synchrony metrics. Expecting all units.
+    synchrony_sizes : numpy array
+        The synchrony sizes to compute. Should be pre-sorted.
 
     Returns
     -------
@@ -565,37 +565,32 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
     return synchrony_counts
 
 
-def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
+def compute_synchrony_metrics(sorting_analyzer, unit_ids=None):
     """
     Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
-    "synchrony_size" spikes at the exact same sample index.
+    spikes at the exact same sample index, with synchrony sizes 2, 4 and 8.
 
     Parameters
     ----------
     sorting_analyzer : SortingAnalyzer
         A SortingAnalyzer object.
-    synchrony_sizes : list or tuple, default: (2, 4, 8)
-        The synchrony sizes to compute.
     unit_ids : list or None, default: None
         List of unit ids to compute the synchrony metrics. If None, all units are used.
 
     Returns
     -------
     sync_spike_{X} : dict
         The synchrony metric for synchrony size X.
-        Returns are as many as synchrony_sizes.
 
     References
     ----------
     Based on concepts described in [Grün]_
     This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_
 
     """
-    assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
-    # Sort the synchrony times so we can slice numpy arrays, instead of using dicts
-    synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
-    synchrony_sizes_np.sort()
-    res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])
+    synchrony_sizes = np.array([2, 4, 8])
+
+    res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])
 
     sorting = sorting_analyzer.sorting
 
@@ -606,10 +601,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
 
     spikes = sorting.to_spike_vector()
     all_unit_ids = sorting.unit_ids
-    synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
+    synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes)
 
     synchrony_metrics_dict = {}
-    for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+    for sync_idx, synchrony_size in enumerate(synchrony_sizes):
         sync_id_metrics_dict = {}
         for i, unit_id in enumerate(all_unit_ids):
             if unit_id not in unit_ids:
@@ -623,7 +618,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
     return res(**synchrony_metrics_dict)
 
 
-_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
+_default_params["synchrony"] = dict()
 
 
 def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None):

diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index 4c0890b62b..f51dc3e884 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -39,7 +39,7 @@
     compute_firing_ranges,
     compute_amplitude_cv_metrics,
     compute_sd_ratio,
-    get_synchrony_counts,
+    _get_synchrony_counts,
     compute_quality_metrics,
 )
 
@@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
     one_spike["sample_index"] = spike_times
     one_spike["unit_index"] = spike_units
 
-    sync_count = get_synchrony_counts(one_spike, np.array((2)), [0])
+    sync_count = _get_synchrony_counts(one_spike, [0])
 
     assert np.all(sync_count[0] == np.array([0]))
 
@@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
     two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
     two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
 
-    sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1])
+    sync_count = _get_synchrony_counts(two_spikes, [0, 1])
 
     assert np.all(sync_count[0] == np.array([1, 1]))
 
@@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
     four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
     four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
 
-    sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3])
+    sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3])
 
     assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
     assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
 
@@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
     three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
     three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
 
-    sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2])
+    sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2])
 
     assert np.all(sync_count[0] == np.array([0, 1, 1]))
 
@@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations):
 def test_synchrony_metrics(sorting_analyzer_simple):
     sorting_analyzer = sorting_analyzer_simple
     sorting = sorting_analyzer.sorting
-    synchrony_sizes = (2, 3, 4)
-    synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes)
-    print(synchrony_metrics)
+    synchrony_metrics = compute_synchrony_metrics(sorting_analyzer)
+
+    synchrony_sizes = np.array([2, 4, 8])
 
     # check returns
     for size in synchrony_sizes:
@@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple):
         sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level)
         sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory")
 
-        previous_synchrony_metrics = compute_synchrony_metrics(
-            previous_sorting_analyzer, synchrony_sizes=synchrony_sizes
-        )
-        current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes)
+        previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer)
+        current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync)
         print(current_synchrony_metrics)
         # check that all values increased
         for i, col in enumerate(previous_synchrony_metrics._fields):
@@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):
 
     unit_ids_subset = [3, 7]
 
-    synchrony_sizes = (2,)
-    (synchrony_metrics,) = compute_synchrony_metrics(
-        sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
-    )
+    synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset)
 
-    assert list(synchrony_metrics.keys()) == [3, 7]
+    assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7]
+    assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7]
+    assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7]
 
 
 def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):
-    # all_unit_ids = sorting_analyzer_simple.sorting.unit_ids
-
-    synchrony_sizes = (2,)
-    (synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)
-
-    assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)
+    synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple)
+    assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids)
 
 
 @pytest.mark.sortingcomponents
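A hedged usage sketch of the API after this patch (the analyzer construction is illustrative; only `compute_synchrony_metrics` and its return shape are taken from the patch itself):

    import spikeinterface.full as si
    from spikeinterface.qualitymetrics import compute_synchrony_metrics

    recording, sorting = si.generate_ground_truth_recording(durations=[60.0])
    sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="memory")

    # no synchrony_sizes argument anymore: sizes 2, 4 and 8 are built in
    synchrony = compute_synchrony_metrics(sorting_analyzer)
    print(synchrony.sync_spike_2)  # {unit_id: rate} for pairwise coincidences
    print(synchrony.sync_spike_8)  # ... and for 8-fold coincidences
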
From 45eb5b74e58061ee04dcb2a4bba10dbcf2a2c892 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Mon, 2 Dec 2024 09:31:23 +0000
Subject: [PATCH 07/15] Add warning and ability to pass synchrony_sizes

---
 src/spikeinterface/qualitymetrics/__init__.py     | 2 +-
 src/spikeinterface/qualitymetrics/misc_metrics.py | 6 +++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/__init__.py b/src/spikeinterface/qualitymetrics/__init__.py
index 9d604f6ae2..754c82d8e3 100644
--- a/src/spikeinterface/qualitymetrics/__init__.py
+++ b/src/spikeinterface/qualitymetrics/__init__.py
@@ -6,4 +6,4 @@
     get_default_qm_params,
 )
 from .pca_metrics import get_quality_pca_metric_list
-from .misc_metrics import get_synchrony_counts
+from .misc_metrics import _get_synchrony_counts

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index b0e0a0ad19..2f178c46f3 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -565,7 +565,7 @@ def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4,
     return synchrony_counts
 
 
-def compute_synchrony_metrics(sorting_analyzer, unit_ids=None):
+def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None):
     """
     Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
     spikes at the exact same sample index, with synchrony sizes 2, 4 and 8.
@@ -588,6 +588,10 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None):
     This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_
 
     """
+    if synchrony_sizes is not None:
+        warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`"
+        warnings.warn(warning_message)
+
     synchrony_sizes = np.array([2, 4, 8])
 
     res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])

From 2081916e33d467145223a7c3099aca556f6e3864 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Tue, 3 Dec 2024 14:18:15 +0000
Subject: [PATCH 08/15] respond to review

---
 src/spikeinterface/qualitymetrics/misc_metrics.py  | 10 ++++++----
 .../qualitymetrics/tests/test_metrics_functions.py |  8 ++++----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index 2f178c46f3..6007de379c 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -520,7 +520,7 @@ def compute_sliding_rp_violations(
     )
 
 
-def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])):
+def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
     """
     Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.
 
@@ -530,7 +530,7 @@ def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4,
         Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
     all_unit_ids : list or None, default: None
         List of unit ids to compute the synchrony metrics. Expecting all units.
-    synchrony_sizes : numpy array
+    synchrony_sizes : None or np.array, default: None
         The synchrony sizes to compute. Should be pre-sorted.
 
     Returns
@@ -576,6 +576,8 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
         A SortingAnalyzer object.
     unit_ids : list or None, default: None
         List of unit ids to compute the synchrony metrics. If None, all units are used.
+    synchrony_sizes: None, default: None
+        Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes.
 
     Returns
     -------
@@ -590,7 +592,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
     """
     if synchrony_sizes is not None:
         warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`"
-        warnings.warn(warning_message)
+        warnings.warn(warning_message, DeprecationWarning, stacklevel=2)
 
     synchrony_sizes = np.array([2, 4, 8])
 
@@ -605,7 +607,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
 
     spikes = sorting.to_spike_vector()
     all_unit_ids = sorting.unit_ids
-    synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes)
+    synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids)
 
     synchrony_metrics_dict = {}
     for sync_idx, synchrony_size in enumerate(synchrony_sizes):

diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index f51dc3e884..ae4c7ab62d 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
     one_spike["sample_index"] = spike_times
     one_spike["unit_index"] = spike_units
 
-    sync_count = _get_synchrony_counts(one_spike, [0])
+    sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0])
 
     assert np.all(sync_count[0] == np.array([0]))
 
@@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
     two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
     two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
 
-    sync_count = _get_synchrony_counts(two_spikes, [0, 1])
+    sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1])
 
     assert np.all(sync_count[0] == np.array([1, 1]))
 
@@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
     four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
     four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
 
-    sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3])
+    sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3])
 
     assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
     assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
 
@@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
     three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
     three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
 
-    sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2])
+    sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2])
 
     assert np.all(sync_count[0] == np.array([0, 1, 1]))
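After patches 07-08, passing the old argument still runs but is ignored apart from a DeprecationWarning. A hedged sketch of what a caller sees (the analyzer setup is illustrative):

    import warnings
    import spikeinterface.full as si
    from spikeinterface.qualitymetrics import compute_synchrony_metrics

    recording, sorting = si.generate_ground_truth_recording(durations=[10.0])
    sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="memory")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 3))

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    print(metrics._fields)  # ('sync_spike_2', 'sync_spike_4', 'sync_spike_8'); sizes are fixed
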
From 09ff624817b53d30d35f1e4f9060edabab45a308 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 4 Dec 2024 10:33:40 +0100
Subject: [PATCH 09/15] Remove venv in full-tests-with-codecov

---
 .github/actions/build-test-environment/action.yml | 36 +++++++------------
 .github/workflows/all-tests.yml                   |  2 +-
 2 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml
index 723e8a702f..a212bd64d5 100644
--- a/.github/actions/build-test-environment/action.yml
+++ b/.github/actions/build-test-environment/action.yml
@@ -1,41 +1,20 @@
 name: Install packages
 description: This action installs the package and its dependencies for testing
 
-inputs:
-  python-version:
-    description: 'Python version to set up'
-    required: false
-  os:
-    description: 'Operating system to set up'
-    required: false
-
 runs:
   using: "composite"
   steps:
     - name: Install dependencies
       run: |
-        sudo apt install git
         git config --global user.email "CI@example.com"
         git config --global user.name "CI Almighty"
-        python -m venv ${{ github.workspace }}/test_env  # Environment used in the caching step
-        python -m pip install -U pip  # Official recommended way
-        source ${{ github.workspace }}/test_env/bin/activate
         pip install tabulate  # This produces summaries at the end
        pip install -e .[test,extractors,streaming_extractors,test_extractors,full]
       shell: bash
-    - name: Force installation of latest dev from key-packages when running dev (not release)
-      run: |
-        source ${{ github.workspace }}/test_env/bin/activate
-        spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
-        if [ $spikeinterface_is_dev_version = "True" ]; then
-          echo "Running spikeinterface dev version"
-          pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
-          pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
-        fi
-        echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
+    - name: Install git-annex
       shell: bash
-    - name: git-annex install
       run: |
+        pip install datalad-installer
         wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
         mkdir /home/runner/work/installation
         mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
@@ -44,4 +23,15 @@ runs:
         tar xvzf git-annex-standalone-amd64.tar.gz
         echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
         cd $workdir
+        git config --global filter.annex.process "git-annex filter-process"  # recommended for efficiency
+    - name: Force installation of latest dev from key-packages when running dev (not release)
+      run: |
+        source ${{ github.workspace }}/test_env/bin/activate
+        spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
+        if [ $spikeinterface_is_dev_version = "True" ]; then
+          echo "Running spikeinterface dev version"
+          pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
+          pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
+        fi
+        echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
       shell: bash

diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml
index dcaec8b272..a9c840d5d5 100644
--- a/.github/workflows/all-tests.yml
+++ b/.github/workflows/all-tests.yml
@@ -47,7 +47,7 @@ jobs:
             echo "$file was changed"
           done
 
-      - name: Set testing environment  # This decides which tests are run and whether to install especial dependencies
+      - name: Set testing environment  # This decides which tests are run and whether to install special dependencies
         shell: bash
         run: |
           changed_files="${{ steps.changed-files.outputs.all_changed_files }}"

From 8500b9d0f4488794dcc6d6b71afec2ebf4697b1d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 4 Dec 2024 10:48:11 +0100
Subject: [PATCH 10/15] Oops

---
 .github/actions/build-test-environment/action.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml
index a212bd64d5..c2524d2c16 100644
--- a/.github/actions/build-test-environment/action.yml
+++ b/.github/actions/build-test-environment/action.yml
@@ -26,7 +26,6 @@ runs:
         git config --global filter.annex.process "git-annex filter-process"  # recommended for efficiency
     - name: Force installation of latest dev from key-packages when running dev (not release)
       run: |
-        source ${{ github.workspace }}/test_env/bin/activate
         spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
         if [ $spikeinterface_is_dev_version = "True" ]; then
           echo "Running spikeinterface dev version"
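The gate these CI steps rely on is a simple package flag; a hedged sketch (`DEV_MODE` is the attribute used in the step above, the printout is illustrative):

    import spikeinterface

    if spikeinterface.DEV_MODE:
        # the CI step installs neo and probeinterface from their git repos here
        print("dev version:", spikeinterface.__version__)
    else:
        print("release version:", spikeinterface.__version__)
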
From 922606b6d4d279da103b7e7edde3ecb79a76e3c8 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 4 Dec 2024 11:16:01 +0100
Subject: [PATCH 11/15] Oops 2

---
 .github/workflows/full-test-with-codecov.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml
index 407c614ebf..f8ed2aa7a9 100644
--- a/.github/workflows/full-test-with-codecov.yml
+++ b/.github/workflows/full-test-with-codecov.yml
@@ -45,7 +45,6 @@ jobs:
         env:
           HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell
         run: |
-          source ${{ github.workspace }}/test_env/bin/activate
           pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1
           echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY
           python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY

From 986a74a30c94a49ed2a2dd6183e8ddc078105b85 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 5 Dec 2024 09:02:31 +0100
Subject: [PATCH 12/15] Pin ONE-API version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index fc09ad9198..22fbdc7f22 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,7 +73,7 @@ extractors = [
 ]
 
 streaming_extractors = [
-    "ONE-api>=2.7.0",  # alf sorter and streaming IBL
+    "ONE-api>=2.7.0,<2.10.0",  # alf sorter and streaming IBL
     "ibllib>=2.36.0",  # streaming IBL
     # Following dependencies are for streaming with nwb files
     "pynwb>=2.6.0",

From 96da22f7ac509bfc83a2a90eed06d58e3f71f990 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 5 Dec 2024 10:00:30 +0000
Subject: [PATCH 13/15] Correct method default in docstring

---
 src/spikeinterface/postprocessing/unit_locations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py
index 3f6dd47eec..bea06fd8f5 100644
--- a/src/spikeinterface/postprocessing/unit_locations.py
+++ b/src/spikeinterface/postprocessing/unit_locations.py
@@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension):
     ----------
     sorting_analyzer : SortingAnalyzer
         A SortingAnalyzer object
-    method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass"
+    method : "monopolar_triangulation" or "center_of_mass" or "grid_convolution", default: "monopolar_triangulation"
         The method to use for localization
     **method_kwargs : dict, default: {}
         Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`.
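A hedged usage sketch for the docstring corrected in patch 13 (the analyzer setup is illustrative; the default method name comes from the patch):

    import spikeinterface.full as si

    recording, sorting = si.generate_ground_truth_recording(durations=[10.0])
    sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="memory")
    sorting_analyzer.compute(["random_spikes", "templates"])  # required dependencies

    ext = sorting_analyzer.compute("unit_locations")  # default: monopolar_triangulation
    print(ext.get_data().shape)  # one location per unit; 3-D for monopolar triangulation

    # the other documented methods can be requested explicitly
    sorting_analyzer.compute("unit_locations", method="center_of_mass")
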
From 10d459f3d45315cee3079e3b14428222487ef9c6 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 5 Dec 2024 15:18:59 +0000
Subject: [PATCH 14/15] change or to | in docstring

---
 src/spikeinterface/postprocessing/unit_locations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py
index bea06fd8f5..df19458316 100644
--- a/src/spikeinterface/postprocessing/unit_locations.py
+++ b/src/spikeinterface/postprocessing/unit_locations.py
@@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension):
     ----------
     sorting_analyzer : SortingAnalyzer
         A SortingAnalyzer object
-    method : "monopolar_triangulation" or "center_of_mass" or "grid_convolution", default: "monopolar_triangulation"
+    method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation"
         The method to use for localization
     **method_kwargs : dict, default: {}
         Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`.

From 4c7b6a5be65af4aa4ce6461e84956455f970942f Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Wed, 11 Dec 2024 09:53:59 +0100
Subject: [PATCH 15/15] Patch

---
 src/spikeinterface/widgets/unit_waveforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py
index c593836061..3b31eacee5 100644
--- a/src/spikeinterface/widgets/unit_waveforms.py
+++ b/src/spikeinterface/widgets/unit_waveforms.py
@@ -565,7 +565,7 @@ def _update_plot(self, change):
             channel_locations = self.sorting_analyzer.get_channel_locations()
         else:
             unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids]
-            templates = self.templates.templates_array[unit_indices]
+            templates = self.templates.get_dense_templates()[unit_indices]
             templates_shadings = None
             channel_locations = self.templates.get_channel_locations()
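Why patch 15 helps, as a hedged sketch: a Templates object can be sparse, in which case `templates_array` holds channel-sparse waveforms, while `get_dense_templates()` always returns a dense (num_units, num_samples, num_channels) array, which is what the widget indexes:

    def dense_templates_for_units(templates, unit_ids):
        """Mirror of the patched widget logic (illustrative helper, not library API)."""
        unit_indices = [list(templates.unit_ids).index(unit_id) for unit_id in unit_ids]
        return templates.get_dense_templates()[unit_indices]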