Skip to content

Commit

Permalink
Merge branch 'main' into curation_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow authored Dec 12, 2024
2 parents be79a83 + 6fde997 commit 77b8c81
Show file tree
Hide file tree
Showing 13 changed files with 59 additions and 72 deletions.
35 changes: 12 additions & 23 deletions .github/actions/build-test-environment/action.yml
Original file line number Diff line number Diff line change
@@ -1,41 +1,20 @@
name: Install packages
description: This action installs the package and its dependencies for testing

inputs:
python-version:
description: 'Python version to set up'
required: false
os:
description: 'Operating system to set up'
required: false

runs:
using: "composite"
steps:
- name: Install dependencies
run: |
sudo apt install git
git config --global user.email "[email protected]"
git config --global user.name "CI Almighty"
python -m venv ${{ github.workspace }}/test_env # Environment used in the caching step
python -m pip install -U pip # Official recommended way
source ${{ github.workspace }}/test_env/bin/activate
pip install tabulate # This produces summaries at the end
pip install -e .[test,extractors,streaming_extractors,test_extractors,full]
shell: bash
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
source ${{ github.workspace }}/test_env/bin/activate
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
fi
echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
- name: Install git-annex
shell: bash
- name: git-annex install
run: |
pip install datalad-installer
wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
mkdir /home/runner/work/installation
mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
Expand All @@ -44,4 +23,14 @@ runs:
tar xvzf git-annex-standalone-amd64.tar.gz
echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
cd $workdir
git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
fi
echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/all-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
echo "$file was changed"
done
- name: Set testing environment # This decides which tests are run and whether to install especial dependencies
- name: Set testing environment # This decides which tests are run and whether to install special dependencies
shell: bash
run: |
changed_files="${{ steps.changed-files.outputs.all_changed_files }}"
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/full-test-with-codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
env:
HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell
run: |
source ${{ github.workspace }}/test_env/bin/activate
pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1
echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY
python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY
Expand Down
2 changes: 1 addition & 1 deletion doc/get_started/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
'min_spikes': 0,
'window_size_s': 1},
'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
'synchrony': {'synchrony_sizes': (2, 4, 8)}}
'synchrony': {}
Since the recording is very short, let’s change some parameters to
Expand Down
4 changes: 2 additions & 2 deletions doc/modules/qualitymetrics/synchrony.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u
Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
within and across spike trains.

Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.



Expand All @@ -29,7 +29,7 @@ Example code
import spikeinterface.qualitymetrics as sqm
# Combine a sorting and recording into a sorting_analyzer
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
# synchrony is a tuple of dicts with the synchrony metrics for each unit
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ extractors = [
]

streaming_extractors = [
"ONE-api>=2.7.0", # alf sorter and streaming IBL
"ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL
"ibllib>=2.36.0", # streaming IBL
# Following dependencies are for streaming with nwb files
"pynwb>=2.6.0",
Expand Down
4 changes: 4 additions & 0 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,8 @@ def __init__(
else:
gains, offsets, locations, groups = self._fetch_main_properties_backend()
self.extra_requirements.append("h5py")
if stream_mode is not None:
self.extra_requirements.append(stream_mode)
self.set_channel_gains(gains)
self.set_channel_offsets(offsets)
if locations is not None:
Expand Down Expand Up @@ -1100,6 +1102,8 @@ def __init__(
for property_name, property_values in properties.items():
values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values]
self.set_property(property_name, values)
if stream_mode is not None:
self.extra_requirements.append(stream_mode)

if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/unit_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension):
----------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object
method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass"
method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation"
The method to use for localization
**method_kwargs : dict, default: {}
Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`.
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/preprocessing/silence_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
self.add_recording_segment(rec_segment)

self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, noise_generator=noise_generator)
self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed)
self._kwargs.update(random_chunk_kwargs)


class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/qualitymetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
get_default_qm_params,
)
from .pca_metrics import get_quality_pca_metric_list
from .misc_metrics import get_synchrony_counts
from .misc_metrics import _get_synchrony_counts
33 changes: 17 additions & 16 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,18 +520,18 @@ def compute_sliding_rp_violations(
)


def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
"""
Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.
Parameters
----------
spikes : np.array
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
synchrony_sizes : numpy array
The synchrony sizes to compute. Should be pre-sorted.
all_unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. Expecting all units.
synchrony_sizes : None or np.array, default: None
The synchrony sizes to compute. Should be pre-sorted.
Returns
-------
Expand Down Expand Up @@ -565,37 +565,38 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
return synchrony_counts


def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None):
"""
Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
"synchrony_size" spikes at the exact same sample index.
spikes at the exact same sample index, with synchrony sizes 2, 4 and 8.
Parameters
----------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object.
synchrony_sizes : list or tuple, default: (2, 4, 8)
The synchrony sizes to compute.
unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. If None, all units are used.
synchrony_sizes: None, default: None
Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes.
Returns
-------
sync_spike_{X} : dict
The synchrony metric for synchrony size X.
Returns are as many as synchrony_sizes.
References
----------
Based on concepts described in [Grün]_
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""
assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
# Sort the synchrony times so we can slice numpy arrays, instead of using dicts
synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
synchrony_sizes_np.sort()

res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])
if synchrony_sizes is not None:
warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`"
warnings.warn(warning_message, DeprecationWarning, stacklevel=2)

synchrony_sizes = np.array([2, 4, 8])

res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])

sorting = sorting_analyzer.sorting

Expand All @@ -606,10 +607,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_

spikes = sorting.to_spike_vector()
all_unit_ids = sorting.unit_ids
synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids)

synchrony_metrics_dict = {}
for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
for sync_idx, synchrony_size in enumerate(synchrony_sizes):
sync_id_metrics_dict = {}
for i, unit_id in enumerate(all_unit_ids):
if unit_id not in unit_ids:
Expand All @@ -623,7 +624,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
return res(**synchrony_metrics_dict)


_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
_default_params["synchrony"] = dict()


def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None):
Expand Down
39 changes: 16 additions & 23 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
get_synchrony_counts,
_get_synchrony_counts,
compute_quality_metrics,
)

Expand Down Expand Up @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
one_spike["sample_index"] = spike_times
one_spike["unit_index"] = spike_units

sync_count = get_synchrony_counts(one_spike, np.array((2)), [0])
sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0])

assert np.all(sync_count[0] == np.array([0]))

Expand All @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1])
sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1])

assert np.all(sync_count[0] == np.array([1, 1]))

Expand All @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3])
sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3])

assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
Expand All @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2])
sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2])

assert np.all(sync_count[0] == np.array([0, 1, 1]))

Expand Down Expand Up @@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations):
def test_synchrony_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
sorting = sorting_analyzer.sorting
synchrony_sizes = (2, 3, 4)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes)
print(synchrony_metrics)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer)

synchrony_sizes = np.array([2, 4, 8])

# check returns
for size in synchrony_sizes:
Expand All @@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple):
sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level)
sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory")

previous_synchrony_metrics = compute_synchrony_metrics(
previous_sorting_analyzer, synchrony_sizes=synchrony_sizes
)
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes)
previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer)
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync)
print(current_synchrony_metrics)
# check that all values increased
for i, col in enumerate(previous_synchrony_metrics._fields):
Expand All @@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):

unit_ids_subset = [3, 7]

synchrony_sizes = (2,)
(synchrony_metrics,) = compute_synchrony_metrics(
sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset)

assert list(synchrony_metrics.keys()) == [3, 7]
assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7]
assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7]
assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7]


def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):

# all_unit_ids = sorting_analyzer_simple.sorting.unit_ids

synchrony_sizes = (2,)
(synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)

assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple)
assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids)


@pytest.mark.sortingcomponents
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def _update_plot(self, change):
channel_locations = self.sorting_analyzer.get_channel_locations()
else:
unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids]
templates = self.templates.templates_array[unit_indices]
templates = self.templates.get_dense_templates()[unit_indices]
templates_shadings = None
channel_locations = self.templates.get_channel_locations()

Expand Down

0 comments on commit 77b8c81

Please sign in to comment.