Skip to content

Commit

Permalink
Merge branch 'main' into fix-3540
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Dec 10, 2024
2 parents 9b2875f + 665adf8 commit ab76b6b
Show file tree
Hide file tree
Showing 21 changed files with 425 additions and 215 deletions.
35 changes: 12 additions & 23 deletions .github/actions/build-test-environment/action.yml
Original file line number Diff line number Diff line change
@@ -1,41 +1,20 @@
name: Install packages
description: This action installs the package and its dependencies for testing

inputs:
python-version:
description: 'Python version to set up'
required: false
os:
description: 'Operating system to set up'
required: false

runs:
using: "composite"
steps:
- name: Install dependencies
run: |
sudo apt install git
git config --global user.email "[email protected]"
git config --global user.name "CI Almighty"
python -m venv ${{ github.workspace }}/test_env # Environment used in the caching step
python -m pip install -U pip # Official recommended way
source ${{ github.workspace }}/test_env/bin/activate
pip install tabulate # This produces summaries at the end
pip install -e .[test,extractors,streaming_extractors,test_extractors,full]
shell: bash
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
source ${{ github.workspace }}/test_env/bin/activate
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
fi
echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
- name: Install git-annex
shell: bash
- name: git-annex install
run: |
pip install datalad-installer
wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
mkdir /home/runner/work/installation
mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
Expand All @@ -44,4 +23,14 @@ runs:
tar xvzf git-annex-standalone-amd64.tar.gz
echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
cd $workdir
git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
fi
echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/all-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
echo "$file was changed"
done
- name: Set testing environment # This decides which tests are run and whether to install especial dependencies
- name: Set testing environment # This decides which tests are run and whether to install special dependencies
shell: bash
run: |
changed_files="${{ steps.changed-files.outputs.all_changed_files }}"
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/full-test-with-codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
env:
HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell
run: |
source ${{ github.workspace }}/test_env/bin/activate
pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1
echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY
python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY
Expand Down
2 changes: 1 addition & 1 deletion doc/get_started/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
'min_spikes': 0,
'window_size_s': 1},
'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
'synchrony': {'synchrony_sizes': (2, 4, 8)}}
'synchrony': {}
Since the recording is very short, let’s change some parameters to
Expand Down
4 changes: 2 additions & 2 deletions doc/modules/qualitymetrics/synchrony.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u
Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
within and across spike trains.

Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.



Expand All @@ -29,7 +29,7 @@ Example code
import spikeinterface.qualitymetrics as sqm
# Combine a sorting and recording into a sorting_analyzer
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
# synchrony is a tuple of dicts with the synchrony metrics for each unit
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ extractors = [
]

streaming_extractors = [
"ONE-api>=2.7.0", # alf sorter and streaming IBL
"ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL
"ibllib>=2.36.0", # streaming IBL
# Following dependencies are for streaming with nwb files
"pynwb>=2.6.0",
Expand Down
14 changes: 5 additions & 9 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import warnings
from pathlib import Path

Expand All @@ -7,14 +8,9 @@

from .base import BaseSegment
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import (
convert_bytes_to_str,
convert_seconds_to_str,
)
from .recording_tools import write_binary_recording


from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs
from .recording_tools import write_binary_recording


class BaseRecording(BaseRecordingSnippets):
Expand Down Expand Up @@ -950,11 +946,11 @@ def time_to_sample_index(self, time_s):
sample_index = time_s * self.sampling_frequency
else:
sample_index = (time_s - self.t_start) * self.sampling_frequency
sample_index = round(sample_index)
sample_index = np.round(sample_index).astype(int)
else:
sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1

return int(sample_index)
return sample_index

def get_num_samples(self) -> int:
"""Returns the number of samples in this signal segment
Expand Down
121 changes: 77 additions & 44 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,10 @@ class ComputeTemplateMetrics(AnalyzerExtension):
include_multi_channel_metrics : bool, default: False
Whether to compute multi-channel metrics
delete_existing_metrics : bool, default: False
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged.
metrics_kwargs : dict
Additional arguments to pass to the metric functions. Including:
* recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7
* peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2
* peak_width_ms: the width in samples to detect peaks, default: 0.2
* depth_direction: the direction to compute velocity above and below, default: "y" (see notes)
* min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5
* min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7
* exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp"
* min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5
* spread_threshold: the threshold to compute the spread, default: 0.2
* spread_smooth_um: the smoothing in um to compute the spread, default: 20
* column_range: the range in um in the horizontal direction to consider channels for velocity, default: None
- If None, all channels all channels are considered
- If 0 or 1, only the "column" that includes the max channel is considered
- If > 1, only channels within range (+/-) um from the max channel horizontal position are used
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged.
metric_params : dict of dicts or None, default: None
Dictionary with parameters for template metrics calculation.
Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()`
Returns
-------
Expand All @@ -100,15 +87,29 @@ class ComputeTemplateMetrics(AnalyzerExtension):
need_recording = False
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

min_channels_for_multi_channel_warning = 10

def _handle_backward_compatibility_on_load(self):

# For backwards compatibility - this reformats metrics_kwargs as metric_params
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None:

metric_params = {}
for metric_name in self.params["metric_names"]:
metric_params[metric_name] = deepcopy(metrics_kwargs)
self.params["metric_params"] = metric_params

del self.params["metrics_kwargs"]

def _set_params(
self,
metric_names=None,
peak_sign="neg",
upsampling_factor=10,
sparsity=None,
metric_params=None,
metrics_kwargs=None,
include_multi_channel_metrics=False,
delete_existing_metrics=False,
Expand All @@ -134,33 +135,24 @@ def _set_params(
if include_multi_channel_metrics:
metric_names += get_multi_channel_template_metric_names()

if metrics_kwargs is None:
metrics_kwargs_ = _default_function_kwargs.copy()
if len(other_kwargs) > 0:
for m in other_kwargs:
if m in metrics_kwargs_:
metrics_kwargs_[m] = other_kwargs[m]
else:
metrics_kwargs_ = _default_function_kwargs.copy()
metrics_kwargs_.update(metrics_kwargs)
if metrics_kwargs is not None and metric_params is None:
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead"
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead"

metric_params = {}
for metric_name in metric_names:
metric_params[metric_name] = deepcopy(metrics_kwargs)

metric_params_ = get_default_tm_params(metric_names)
for k in metric_params_:
if metric_params is not None and k in metric_params:
metric_params_[k].update(metric_params[k])

metrics_to_compute = metric_names
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:

existing_params = tm_extension.params["metrics_kwargs"]
# checks that existing metrics were calculated using the same params
if existing_params != metrics_kwargs_:
warnings.warn(
f"The parameters used to calculate the previous template metrics are different"
f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
)
tm_extension.params["metric_names"] = []
existing_metric_names = []
else:
existing_metric_names = tm_extension.params["metric_names"]

existing_metric_names = tm_extension.params["metric_names"]
existing_metric_names_propogated = [
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
]
Expand All @@ -171,7 +163,7 @@ def _set_params(
sparsity=sparsity,
peak_sign=peak_sign,
upsampling_factor=int(upsampling_factor),
metrics_kwargs=metrics_kwargs_,
metric_params=metric_params_,
delete_existing_metrics=delete_existing_metrics,
metrics_to_compute=metrics_to_compute,
)
Expand Down Expand Up @@ -273,7 +265,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
sampling_frequency=sampling_frequency_up,
trough_idx=trough_idx,
peak_idx=peak_idx,
**self.params["metrics_kwargs"],
**self.params["metric_params"][metric_name],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
Expand Down Expand Up @@ -312,7 +304,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
template_upsampled,
channel_locations=channel_locations_sparse,
sampling_frequency=sampling_frequency_up,
**self.params["metrics_kwargs"],
**self.params["metric_params"][metric_name],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
Expand All @@ -326,8 +318,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri

def _run(self, verbose=False):

delete_existing_metrics = self.params["delete_existing_metrics"]
metrics_to_compute = self.params["metrics_to_compute"]
delete_existing_metrics = self.params["delete_existing_metrics"]

# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
Expand All @@ -343,9 +335,21 @@ def _run(self, verbose=False):
):
existing_metrics = tm_extension.params["metric_names"]

existing_metrics = []
# here we get in the loaded via the dict only (to avoid full loading from disk after params reset)
tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None)
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metrics_to_compute):
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]
# some metrics names produce data columns with other names. This deals with that.
for column_name in tm_compute_name_to_column_names[metric_name]:
computed_metrics[column_name] = tm_extension.data["metrics"][column_name]

self.data["metrics"] = computed_metrics

Expand All @@ -372,6 +376,35 @@ def _get_data(self):
)


def get_default_tm_params(metric_names):
if metric_names is None:
metric_names = get_template_metric_names()

base_tm_params = _default_function_kwargs

metric_params = {}
for metric_name in metric_names:
metric_params[metric_name] = deepcopy(base_tm_params)

return metric_params


# a dict converting the name of the metric for computation to the output of that computation
tm_compute_name_to_column_names = {
"peak_to_valley": ["peak_to_valley"],
"peak_trough_ratio": ["peak_trough_ratio"],
"half_width": ["half_width"],
"repolarization_slope": ["repolarization_slope"],
"recovery_slope": ["recovery_slope"],
"num_positive_peaks": ["num_positive_peaks"],
"num_negative_peaks": ["num_negative_peaks"],
"velocity_above": ["velocity_above"],
"velocity_below": ["velocity_below"],
"exp_decay": ["exp_decay"],
"spread": ["spread"],
}


def get_trough_and_peak_idx(template):
"""
Return the indices into the input template of the detected trough
Expand Down
Loading

0 comments on commit ab76b6b

Please sign in to comment.