Refactoring of GroundTruthStudy #1983

Merged: 31 commits (Sep 27, 2023)

Commits (31)
0acc125
Start GroundTruthStudy refactoring.
samuelgarcia Sep 8, 2023
462961f
new GroundTruthStudy wip
samuelgarcia Sep 8, 2023
e0af88d
Make internal sorters able to be run with none dumpable to json recor…
samuelgarcia Sep 8, 2023
9905bf5
wip
samuelgarcia Sep 8, 2023
98fa0f8
gt_study wip
samuelgarcia Sep 9, 2023
f0940a5
gt study wip
samuelgarcia Sep 9, 2023
b0267dc
Add levels concept in GTStudy
samuelgarcia Sep 10, 2023
0750638
wip gtstudy
samuelgarcia Sep 11, 2023
ee2eb2f
STart porting matplotlib widgets related to ground truth study.
samuelgarcia Sep 12, 2023
d80341c
remove gtstudy widgets from legacy and port some of then in the API.
samuelgarcia Sep 12, 2023
f97f76a
Clean
samuelgarcia Sep 12, 2023
cf9a3b5
Merge branch 'run_sorter_jobs' of github.com:samuelgarcia/spikeinterf…
samuelgarcia Sep 13, 2023
ba2e961
small fix
samuelgarcia Sep 13, 2023
9b5b28b
small fix
samuelgarcia Sep 13, 2023
8d9ce49
group in same file CollisionGTComparison and CollisionGTStudy
samuelgarcia Sep 19, 2023
b1297e6
Update CollisionGTStudy and CorrelogramGTStudy
samuelgarcia Sep 19, 2023
8a7a90e
wip
samuelgarcia Sep 19, 2023
fe6f60f
Re move studytools.py. Not needed anymore.
samuelgarcia Sep 19, 2023
77505ad
rm studytools part2
samuelgarcia Sep 19, 2023
b5376a9
Modify doc for gt study
samuelgarcia Sep 19, 2023
469b520
merge with main
samuelgarcia Sep 19, 2023
d7aaa95
gt study widget xlim
samuelgarcia Sep 19, 2023
4b994cc
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Sep 22, 2023
5029445
Apply suggestions from code review
samuelgarcia Sep 26, 2023
32d3d7a
extract_waveforms_gt must be done on dataset key instead of case key.
samuelgarcia Sep 26, 2023
d48cd68
implement some TODOs
samuelgarcia Sep 27, 2023
af72fbc
oups
samuelgarcia Sep 27, 2023
c0c2163
merge with main
samuelgarcia Sep 27, 2023
6c561f2
more fix after merge with main and the new pickle to file mechanism
samuelgarcia Sep 27, 2023
e8c4b77
Merge branch 'main' into gt_study
samuelgarcia Sep 27, 2023
cb9a228
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
100 changes: 61 additions & 39 deletions doc/modules/comparison.rst
@@ -248,21 +248,19 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for
We also have a high level class to compare many sorters against ground truth:
:py:func:`~spikeinterface.comparison.GroundTruthStudy()`

A study is a systematic performance comparison of several ground truth recordings with several sorters.
A study is a systematic performance comparison of several ground truth recordings with several sorters or several "cases",
such as different parameter sets.

The study class proposes high-level tool functions to run many ground truth comparisons with many sorters
The study class provides high-level tool functions to run many ground truth comparisons with many "cases"
on many recordings and then collect and aggregate the results in an easy way.

The whole mechanism is based on an intrinsic organization into a "study_folder" with several subfolders (see the short sketch after the list):

* raw_files : contain a copy of recordings in binary format
* sorter_folders : contains outputs of sorters
* ground_truth : contains a copy of sorting ground truth in npz format
* sortings: contains light copy of all sorting in npz format
* tables: some tables in csv format

In order to run and rerun the computation all gt_sorting and recordings are copied to a fast and universal format:
binary (for recordings) and npz (for sortings).
* datasets: contains the ground truth datasets
* sorters: contains the outputs of sorters
* sortings: contains a light copy of all sortings
* metrics: contains the computed metrics
* ...
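A short sketch (illustrative only) of listing these subfolders with ``pathlib``, assuming a study folder that was already created as in the example below:

.. code-block:: python

    from pathlib import Path

    # after GroundTruthStudy.create(...) the study folder is expected to contain
    # subfolders such as "datasets", "sorters", "sortings" and "metrics"
    study_folder = Path("a_study_folder")  # same folder name as in the example below
    for sub in sorted(p.name for p in study_folder.iterdir() if p.is_dir()):
        print(sub)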


.. code-block:: python
@@ -274,28 +272,51 @@
import spikeinterface.widgets as sw
from spikeinterface.comparison import GroundTruthStudy

# Setup study folder
rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1)
rec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1)
gt_dict = {
'rec0': (rec0, gt_sorting0),
'rec1': (rec1, gt_sorting1),

# generate 2 simulated datasets (could also be mearec files)
rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)

datasets = {
"toy0": (rec0, gt_sorting0),
"toy1": (rec1, gt_sorting1),
}
study_folder = 'a_study_folder'
study = GroundTruthStudy.create(study_folder, gt_dict)

# all sorters for all recordings in one function.
sorter_list = ['herdingspikes', 'tridesclous', ]
study.run_sorters(sorter_list, mode_if_folder_exists="keep")
# define some "cases": here we want to test tridesclous2 on 2 datasets and spykingcircus on one dataset,
# so it is a two-level study (sorter_name, dataset)
# this could be more complicated, like (sorter_name, dataset, params)
cases = {
("tdc2", "toy0"): {
"label": "tridesclous2 on tetrode0",
"dataset": "toy0",
"run_sorter_params": {
"sorter_name": "tridesclous2",
},
},
("tdc2", "toy1"): {
"label": "tridesclous2 on tetrode1",
"dataset": "toy1",
"run_sorter_params": {
"sorter_name": "tridesclous2",
},
},

("sc", "toy0"): {
"label": "spykingcircus2 on tetrode0",
"dataset": "toy0",
"run_sorter_params": {
"sorter_name": "spykingcircus",
"docker_image": True
},
},
}
# this initializes the study folder
study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases,
levels=["sorter_name", "dataset"])

# You can re-run **run_study_sorters** as many times as you want.
# By default **mode='keep'** so only uncomputed sorters are re-run.
# For instance, just remove the "sorter_folders/rec1/herdingspikes" to re-run
# only one sorter on one recording.
#
# Then we copy the spike sorting outputs into a separate subfolder.
# This allow us to remove the "large" sorter_folders.
study.copy_sortings()

# all cases in one function
study.run_sorters()

# Collect comparisons
#
@@ -306,11 +327,11 @@ binary (for recordings) and npz (for sortings).
# Note: use exhaustive_gt=True when you know exactly how many
# units in ground truth (for synthetic datasets)

# run all comparisons and loop over the results
study.run_comparisons(exhaustive_gt=True)

for (rec_name, sorter_name), comp in study.comparisons.items():
for key, comp in study.comparisons.items():
print('*' * 10)
print(rec_name, sorter_name)
print(key)
# raw counting of tp/fp/...
print(comp.count_score)
# summary
@@ -323,26 +344,27 @@ binary (for recordings) and npz (for sortings).

# Collect summary dataframes and display them
# As shown previously, the performance is returned as a pandas dataframe.
# The :py:func:`~spikeinterface.comparison.aggregate_performances_table()` function,
# The :py:func:`~spikeinterface.comparison.get_performance_by_unit()` function,
# gathers all the outputs in the study folder and merges them in a single dataframe.
# Same idea for :py:func:`~spikeinterface.comparison.get_count_units()`

dataframes = study.aggregate_dataframes()
# this is a dataframe
perfs = study.get_performance_by_unit()

# Pandas dataframes can be nicely displayed as tables in the notebook.
print(dataframes.keys())
# this is a dataframe
unit_counts = study.get_count_units()

# we can also access run times
print(dataframes['run_times'])
run_times = study.get_run_times()
print(run_times)

# Easy plot with seaborn
run_times = dataframes['run_times']
fig1, ax1 = plt.subplots()
sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1)
ax1.set_title('Run times')

##############################################################################

perfs = dataframes['perf_by_unit']
fig2, ax2 = plt.subplots()
sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2)
ax2.set_title('Recall')
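The comments in the snippet above note that case keys can have more than two levels, e.g. ``(sorter_name, dataset, params)``. Below is a minimal sketch of such a three-level setup, reusing ``datasets`` from the example; the third-level names, the extra sorter parameter, and the folder name are hypothetical, not part of this PR:

.. code-block:: python

    # hypothetical three-level cases: (sorter_name, dataset, params)
    cases = {
        ("tdc2", "toy0", "default"): {
            "label": "tridesclous2 on toy0, default parameters",
            "dataset": "toy0",
            "run_sorter_params": {"sorter_name": "tridesclous2"},
        },
        ("tdc2", "toy0", "variant"): {
            "label": "tridesclous2 on toy0, alternative parameters",
            "dataset": "toy0",
            # "some_param" is a placeholder for any real sorter parameter
            "run_sorter_params": {"sorter_name": "tridesclous2", "some_param": 42},
        },
    }
    study = GroundTruthStudy.create("another_study_folder", datasets=datasets, cases=cases,
                                    levels=["sorter_name", "dataset", "params"])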
9 changes: 4 additions & 5 deletions src/spikeinterface/comparison/__init__.py
@@ -28,12 +28,11 @@
compare_multiple_templates,
MultiTemplateComparison,
)
from .collisioncomparison import CollisionGTComparison
from .correlogramcomparison import CorrelogramGTComparison

from .groundtruthstudy import GroundTruthStudy
from .collisionstudy import CollisionGTStudy
from .correlogramstudy import CorrelogramGTStudy
from .studytools import aggregate_performances_table
from .collision import CollisionGTComparison, CollisionGTStudy
from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy

from .hybrid import (
HybridSpikesRecording,
HybridUnitsRecording,
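For orientation, a minimal sketch of the resulting public imports, based only on the import lines shown in this diff (the comparison module now re-exports the collision and correlogram classes from the new collision.py and correlogram.py files, while aggregate_performances_table from the removed studytools.py is no longer exported):

from spikeinterface.comparison import (
    GroundTruthStudy,
    CollisionGTComparison,
    CollisionGTStudy,
    CorrelogramGTComparison,
    CorrelogramGTStudy,
)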
src/spikeinterface/comparison/collision.py
@@ -1,13 +1,15 @@
import numpy as np

from .paircomparisons import GroundTruthComparison
from .groundtruthstudy import GroundTruthStudy
from .comparisontools import make_collision_events

import numpy as np


class CollisionGTComparison(GroundTruthComparison):
"""
This class is an extension of GroundTruthComparison by focusing
to benchmark spike in collision
This class is an extension of GroundTruthComparison focused on benchmarking spikes in collision.

This class needs maintenance and a bit of refactoring.


collision_lag: float
@@ -156,3 +158,73 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
pair_names = pair_names[order]

return similarities, recall_scores, pair_names


class CollisionGTStudy(GroundTruthStudy):
def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
_kwargs = dict()
_kwargs.update(kwargs)
_kwargs["exhaustive_gt"] = exhaustive_gt
_kwargs["collision_lag"] = collision_lag
_kwargs["nbins"] = nbins
GroundTruthStudy.run_comparisons(self, case_keys=case_keys, comparison_class=CollisionGTComparison, **_kwargs)
self.exhaustive_gt = exhaustive_gt
self.collision_lag = collision_lag

def get_lags(self, key):
comp = self.comparisons[key]
fs = comp.sorting1.get_sampling_frequency()
lags = comp.bins / fs * 1000.0
return lags

def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9):
import sklearn.metrics.pairwise

if case_keys is None:
case_keys = self.cases.keys()

self.all_similarities = {}
self.all_recall_scores = {}
self.good_only = good_only

for key in case_keys:
templates = self.get_templates(key)
flat_templates = templates.reshape(templates.shape[0], -1)
similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
comp = self.comparisons[key]
similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
similarity, good_only=good_only, min_accuracy=min_accuracy
)
self.all_similarities[key] = similarities
self.all_recall_scores[key] = recall_scores

def get_mean_over_similarity_range(self, similarity_range, key):
idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1])
all_similarities = self.all_similarities[key][idx]
all_recall_scores = self.all_recall_scores[key][idx]

order = np.argsort(all_similarities)
all_similarities = all_similarities[order]
all_recall_scores = all_recall_scores[order, :]

mean_recall_scores = np.nanmean(all_recall_scores, axis=0)

return mean_recall_scores

def get_lag_profile_over_similarity_bins(self, similarity_bins, key):
all_similarities = self.all_similarities[key]
all_recall_scores = self.all_recall_scores[key]

order = np.argsort(all_similarities)
all_similarities = all_similarities[order]
all_recall_scores = all_recall_scores[order, :]

result = {}

for i in range(similarity_bins.size - 1):
cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
result[(cmin, cmax)] = mean_recall_scores

return result
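A minimal usage sketch for the refactored CollisionGTStudy, using only the methods defined above together with the GroundTruthStudy API shown in the documentation diff. The folder name is illustrative, and it assumes a study created beforehand with datasets/cases as in the documentation example and with ground-truth templates already extracted (see the extract_waveforms_gt commit above):

import numpy as np
from spikeinterface.comparison import CollisionGTStudy

# load an existing study folder (created earlier with CollisionGTStudy.create(...))
study = CollisionGTStudy("a_collision_study_folder")

study.run_sorters()  # inherited from GroundTruthStudy
study.run_comparisons(exhaustive_gt=True, collision_lag=2.0, nbins=11)

# summarize collision recall as a function of template similarity
study.precompute_scores_by_similarities(good_only=False, min_accuracy=0.9)

key = list(study.cases.keys())[0]
lags = study.get_lags(key)  # lag axis in ms for this case
mean_recall = study.get_mean_over_similarity_range((0.5, 1.0), key)
profiles = study.get_lag_profile_over_similarity_bins(np.arange(0.0, 1.1, 0.25), key)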
88 changes: 0 additions & 88 deletions src/spikeinterface/comparison/collisionstudy.py

This file was deleted.
