Merge pull request #1983 from samuelgarcia/gt_study
Refactoring of GroundTruthStudy
alejoe91 authored Sep 27, 2023
2 parents 3826a2d + cb9a228 commit 93f02e8
Showing 19 changed files with 890 additions and 1,539 deletions.
100 changes: 61 additions & 39 deletions doc/modules/comparison.rst
@@ -248,21 +248,19 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for
We also have a high level class to compare many sorters against ground truth:
:py:func:`~spikeinterface.comparison.GroundTruthStudy()`

A study is a systematic performance comparison of several ground truth recordings with several sorters.
A study is a systematic performance comparison of several ground truth recordings with several sorters or several "cases",
such as different parameter sets.

The study class proposes high-level tool functions to run many ground truth comparisons with many sorters
The study class proposes high-level tool functions to run many ground truth comparisons with many "cases"
on many recordings and then collect and aggregate results in an easy way.

The whole mechanism is based on an intrinsic organization into a "study_folder" with several subfolders:

* raw_files : contain a copy of recordings in binary format
* sorter_folders : contains outputs of sorters
* ground_truth : contains a copy of sorting ground truth in npz format
* sortings: contains light copy of all sorting in npz format
* tables: some tables in csv format

In order to run and rerun the computation all gt_sorting and recordings are copied to a fast and universal format:
binary (for recordings) and npz (for sortings).
* datasets: contains ground truth datasets
* sorters: contains outputs of sorters
* sortings: contains a light copy of all sortings
* metrics: contains metrics
* ...
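
As a quick illustration (not part of the docs diff), the layout of an existing study folder can be inspected with plain ``pathlib``; the exact set of subfolders depends on which steps have been run:

.. code-block:: python

    from pathlib import Path

    study_folder = Path("a_study_folder")
    for sub in sorted(p.name for p in study_folder.iterdir() if p.is_dir()):
        print(sub)  # e.g. datasets, sorters, sortings, metrics, ...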


.. code-block:: python
@@ -274,28 +272,51 @@
import spikeinterface.widgets as sw
from spikeinterface.comparison import GroundTruthStudy
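# assumed import for generate_ground_truth_recording used below (not shown in this diff excerpt)
from spikeinterface.core import generate_ground_truth_recording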
# Setup study folder
rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1)
rec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1)
gt_dict = {
    'rec0': (rec0, gt_sorting0),
    'rec1': (rec1, gt_sorting1),
# generate 2 simulated datasets (could also be mearec files)
rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)
datasets = {
    "toy0": (rec0, gt_sorting0),
    "toy1": (rec1, gt_sorting1),
}
study_folder = 'a_study_folder'
study = GroundTruthStudy.create(study_folder, gt_dict)
# all sorters for all recordings in one function.
sorter_list = ['herdingspikes', 'tridesclous', ]
study.run_sorters(sorter_list, mode_if_folder_exists="keep")
# define some "cases": here we want to test tridesclous2 on 2 datasets and spykingcircus on one dataset
# so it is a two level study (sorter_name, dataset)
# this could be more complicated like (sorter_name, dataset, params)
cases = {
    ("tdc2", "toy0"): {
        "label": "tridesclous2 on tetrode0",
        "dataset": "toy0",
        "run_sorter_params": {
            "sorter_name": "tridesclous2",
        },
    },
    ("tdc2", "toy1"): {
        "label": "tridesclous2 on tetrode1",
        "dataset": "toy1",
        "run_sorter_params": {
            "sorter_name": "tridesclous2",
        },
    },
    ("sc", "toy0"): {
        "label": "spykingcircus2 on tetrode0",
        "dataset": "toy0",
        "run_sorter_params": {
            "sorter_name": "spykingcircus",
            "docker_image": True
        },
    },
}
# this initializes a folder
study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases,
                                levels=["sorter_name", "dataset"])
# You can re-run **run_study_sorters** as many times as you want.
# By default **mode='keep'** so only uncomputed sorters are re-run.
# For instance, just remove the "sorter_folders/rec1/herdingspikes" to re-run
# only one sorter on one recording.
#
# Then we copy the spike sorting outputs into a separate subfolder.
# This allows us to remove the "large" sorter_folders.
study.copy_sortings()
# all cases in one function
study.run_sorters()
# Collect comparisons
#
@@ -306,11 +327,11 @@
# Note: use exhaustive_gt=True when you know exactly how many
# units are in the ground truth (for synthetic datasets)
# run all comparisons and loop over the results
study.run_comparisons(exhaustive_gt=True)
for (rec_name, sorter_name), comp in study.comparisons.items():
for key, comp in study.comparisons.items():
    print('*' * 10)
    print(rec_name, sorter_name)
    print(key)
    # raw counting of tp/fp/...
    print(comp.count_score)
# summary
@@ -323,26 +344,27 @@
# Collect synthetic dataframes and display
# As shown previously, the performance is returned as a pandas dataframe.
# The :py:func:`~spikeinterface.comparison.aggregate_performances_table()` function,
# The :py:func:`~spikeinterface.comparison.get_performance_by_unit()` function,
# gathers all the outputs in the study folder and merges them in a single dataframe.
# Same idea for :py:func:`~spikeinterface.comparison.get_count_units()`
dataframes = study.aggregate_dataframes()
# this is a dataframe
perfs = study.get_performance_by_unit()
# Pandas dataframes can be nicely displayed as tables in the notebook.
print(dataframes.keys())
# this is a dataframe
unit_counts = study.get_count_units()
# we can also access run times
print(dataframes['run_times'])
run_times = study.get_run_times()
print(run_times)
# Easy plot with seaborn
run_times = dataframes['run_times']
fig1, ax1 = plt.subplots()
sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1)
ax1.set_title('Run times')
##############################################################################
perfs = dataframes['perf_by_unit']
fig2, ax2 = plt.subplots()
sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2)
ax2.set_title('Recall')
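For reference, here is a minimal sketch (not part of this diff) of how the seaborn plots above could look with the refactored accessors. It assumes ``get_run_times()`` and ``get_performance_by_unit()`` return tidy pandas DataFrames indexed by the case keys; the column and level names used below ("dataset", "sorter_name", "run_time", "recall") follow from the levels declared at study creation and are assumptions, not guaranteed names.

.. code-block:: python

    import matplotlib.pyplot as plt
    import seaborn as sns

    # assumed: both accessors return pandas DataFrames indexed by the case keys
    run_times = study.get_run_times()
    perfs = study.get_performance_by_unit()

    fig1, ax1 = plt.subplots()
    sns.barplot(data=run_times.reset_index(), x="dataset", y="run_time", hue="sorter_name", ax=ax1)
    ax1.set_title("Run times")

    fig2, ax2 = plt.subplots()
    sns.swarmplot(data=perfs.reset_index(), x="sorter_name", y="recall", hue="dataset", ax=ax2)
    ax2.set_title("Recall")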
9 changes: 4 additions & 5 deletions src/spikeinterface/comparison/__init__.py
@@ -28,12 +28,11 @@
compare_multiple_templates,
MultiTemplateComparison,
)
from .collisioncomparison import CollisionGTComparison
from .correlogramcomparison import CorrelogramGTComparison

from .groundtruthstudy import GroundTruthStudy
from .collisionstudy import CollisionGTStudy
from .correlogramstudy import CorrelogramGTStudy
from .studytools import aggregate_performances_table
from .collision import CollisionGTComparison, CollisionGTStudy
from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy

from .hybrid import (
HybridSpikesRecording,
HybridUnitsRecording,
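The refactor moves the collision and correlogram comparison/study classes into the consolidated ``collision`` and ``correlogram`` modules, but they remain exposed at the package level, so imports like this sketch keep working:

.. code-block:: python

    from spikeinterface.comparison import (
        GroundTruthStudy,
        CollisionGTComparison,
        CollisionGTStudy,
        CorrelogramGTComparison,
        CorrelogramGTStudy,
    )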
@@ -1,13 +1,15 @@
import numpy as np

from .paircomparisons import GroundTruthComparison
from .groundtruthstudy import GroundTruthStudy
from .comparisontools import make_collision_events

import numpy as np


class CollisionGTComparison(GroundTruthComparison):
"""
This class is an extension of GroundTruthComparison by focusing
to benchmark spike in collision
This class is an extension of GroundTruthComparison, focused on benchmarking spikes in collision.
This class needs maintenance and a bit of refactoring.
collision_lag: float
@@ -156,3 +158,73 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
        pair_names = pair_names[order]

        return similarities, recall_scores, pair_names


class CollisionGTStudy(GroundTruthStudy):
    def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
        _kwargs = dict()
        _kwargs.update(kwargs)
        _kwargs["exhaustive_gt"] = exhaustive_gt
        _kwargs["collision_lag"] = collision_lag
        _kwargs["nbins"] = nbins
        GroundTruthStudy.run_comparisons(self, case_keys=case_keys, comparison_class=CollisionGTComparison, **_kwargs)
        self.exhaustive_gt = exhaustive_gt
        self.collision_lag = collision_lag

    def get_lags(self, key):
        comp = self.comparisons[key]
        fs = comp.sorting1.get_sampling_frequency()
        lags = comp.bins / fs * 1000.0
        return lags

    def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9):
        import sklearn.metrics.pairwise

        if case_keys is None:
            case_keys = self.cases.keys()

        self.all_similarities = {}
        self.all_recall_scores = {}
        self.good_only = good_only

        for key in case_keys:
            templates = self.get_templates(key)
            flat_templates = templates.reshape(templates.shape[0], -1)
            similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
            comp = self.comparisons[key]
            similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
                similarity, good_only=good_only, min_accuracy=min_accuracy
            )
            self.all_similarities[key] = similarities
            self.all_recall_scores[key] = recall_scores

    def get_mean_over_similarity_range(self, similarity_range, key):
        idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1])
        all_similarities = self.all_similarities[key][idx]
        all_recall_scores = self.all_recall_scores[key][idx]

        order = np.argsort(all_similarities)
        all_similarities = all_similarities[order]
        all_recall_scores = all_recall_scores[order, :]

        mean_recall_scores = np.nanmean(all_recall_scores, axis=0)

        return mean_recall_scores

    def get_lag_profile_over_similarity_bins(self, similarity_bins, key):
        all_similarities = self.all_similarities[key]
        all_recall_scores = self.all_recall_scores[key]

        order = np.argsort(all_similarities)
        all_similarities = all_similarities[order]
        all_recall_scores = all_recall_scores[order, :]

        result = {}

        for i in range(similarity_bins.size - 1):
            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
            mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
            result[(cmin, cmax)] = mean_recall_scores

        return result
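
A hedged usage sketch of the refactored ``CollisionGTStudy``, based only on the methods shown above. It assumes the study folder was created with datasets and cases as in the documentation example, that sorters have already been run, that the constructor accepts the study folder like ``GroundTruthStudy``, and that templates are available for each case via ``get_templates()``:

.. code-block:: python

    import numpy as np
    from spikeinterface.comparison import CollisionGTStudy

    study = CollisionGTStudy(study_folder)

    # wraps GroundTruthStudy.run_comparisons with CollisionGTComparison
    study.run_comparisons(exhaustive_gt=True, collision_lag=2.0, nbins=11)

    # template cosine similarity vs collision recall, per case
    study.precompute_scores_by_similarities(good_only=False, min_accuracy=0.9)

    for key in study.cases:
        lags = study.get_lags(key)  # lag axis in ms
        profile = study.get_lag_profile_over_similarity_bins(np.linspace(0.0, 1.0, 5), key)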
88 changes: 0 additions & 88 deletions src/spikeinterface/comparison/collisionstudy.py

This file was deleted.
