Skip to content

Commit

Permalink
Black formatting (#1633)
Browse files Browse the repository at this point in the history
* Add pre-commit-config file

* Black reformatting with line-length 120

---------

Co-authored-by: Heberto Mayorquin <[email protected]>
  • Loading branch information
alejoe91 and h-mayorquin authored May 15, 2023
1 parent 1c78add commit 7c6d356
Show file tree
Hide file tree
Showing 456 changed files with 19,328 additions and 15,806 deletions.
4 changes: 3 additions & 1 deletion src/spikeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
"""
import importlib.metadata

__version__ = importlib.metadata.version("spikeinterface")

from .core import *

import warnings
warnings.simplefilter('always', DeprecationWarning)

warnings.simplefilter("always", DeprecationWarning)

"""
submodules are imported only if needed
Expand Down
49 changes: 37 additions & 12 deletions src/spikeinterface/comparison/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,43 @@
from .comparisontools import (count_matching_events, compute_agreement_score, count_match_spikes,
make_agreement_scores, make_possible_match, make_best_match, make_hungarian_match,
do_score_labels, compare_spike_trains, do_confusion_matrix, do_count_score,
compute_performance,
do_count_event, make_match_count_matrix)
from .paircomparisons import (compare_two_sorters, SymmetricSortingComparison,
compare_sorter_to_ground_truth, GroundTruthComparison,
compare_templates, TemplateComparison)
from .multicomparisons import (compare_multiple_sorters, MultiSortingComparison,
compare_multiple_templates, MultiTemplateComparison)
from .comparisontools import (
count_matching_events,
compute_agreement_score,
count_match_spikes,
make_agreement_scores,
make_possible_match,
make_best_match,
make_hungarian_match,
do_score_labels,
compare_spike_trains,
do_confusion_matrix,
do_count_score,
compute_performance,
do_count_event,
make_match_count_matrix,
)
from .paircomparisons import (
compare_two_sorters,
SymmetricSortingComparison,
compare_sorter_to_ground_truth,
GroundTruthComparison,
compare_templates,
TemplateComparison,
)
from .multicomparisons import (
compare_multiple_sorters,
MultiSortingComparison,
compare_multiple_templates,
MultiTemplateComparison,
)
from .collisioncomparison import CollisionGTComparison
from .correlogramcomparison import CorrelogramGTComparison
from .groundtruthstudy import GroundTruthStudy
from .collisionstudy import CollisionGTStudy
from .correlogramstudy import CorrelogramGTStudy
from .studytools import aggregate_performances_table
from .hybrid import (HybridSpikesRecording, HybridUnitsRecording, generate_injected_sorting,
create_hybrid_units_recording, create_hybrid_spikes_recording)
from .hybrid import (
HybridSpikesRecording,
HybridUnitsRecording,
generate_injected_sorting,
create_hybrid_units_recording,
create_hybrid_spikes_recording,
)
127 changes: 65 additions & 62 deletions src/spikeinterface/comparison/basecomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
from typing import OrderedDict
import numpy as np

from .comparisontools import (make_possible_match, make_best_match, make_hungarian_match)
from .comparisontools import make_possible_match, make_best_match, make_hungarian_match


class BaseComparison:
"""
Base class for all comparisons (SpikeTrain and Template)
"""

def __init__(self, object_list, name_list,
match_score=0.5, chance_score=0.1,
verbose=False):
def __init__(self, object_list, name_list, match_score=0.5, chance_score=0.1, verbose=False):
self.object_list = object_list
self.name_list = name_list
self._verbose = verbose
Expand All @@ -23,18 +21,21 @@ def __init__(self, object_list, name_list,
class BaseMultiComparison(BaseComparison):
"""
Base class for graph-based multi comparison classes.
It handles graph operations, comparisons, and agreements.
"""
def __init__(self, object_list, name_list,
match_score=0.5, chance_score=0.1,
verbose=False):

def __init__(self, object_list, name_list, match_score=0.5, chance_score=0.1, verbose=False):
import networkx as nx
BaseComparison.__init__(self, object_list=object_list,
name_list=name_list,
match_score=match_score,
chance_score=chance_score,
verbose=verbose)

BaseComparison.__init__(
self,
object_list=object_list,
name_list=name_list,
match_score=match_score,
chance_score=chance_score,
verbose=verbose,
)
self.graph = None
self.subgraphs = None
self.clean_graph = None
Expand Down Expand Up @@ -69,8 +70,9 @@ def compute_subgraphs(self):
g = self.clean_graph
else:
g = self.graph

import networkx as nx

subgraphs = (g.subgraph(c).copy() for c in nx.connected_components(g))
sg_object_names = []
sg_units = []
Expand All @@ -84,10 +86,12 @@ def compute_subgraphs(self):
sg_units.append(unit_names)
return sg_object_names, sg_units

def _do_comparison(self, ):
def _do_comparison(
self,
):
# do pairwise matching
if self._verbose:
print('Multicomaprison step 1: pairwise comparison')
print("Multicomaprison step 1: pairwise comparison")

self.comparisons = {}
for i in range(len(self.object_list)):
Expand All @@ -105,9 +109,10 @@ def _do_comparison(self, ):

def _do_graph(self):
if self._verbose:
print('Multicomparison step 2: make graph')
print("Multicomparison step 2: make graph")

import networkx as nx

self.graph = nx.Graph()
# nodes
self._populate_nodes()
Expand All @@ -128,11 +133,11 @@ def _do_graph(self):

def _clean_graph(self):
if self._verbose:
print('Multicomaprison step 3: clean graph')
print("Multicomaprison step 3: clean graph")
clean_graph = self.graph.copy()
import networkx as nx
subgraphs = (clean_graph.subgraph(c).copy()
for c in nx.connected_components(clean_graph))

subgraphs = (clean_graph.subgraph(c).copy() for c in nx.connected_components(clean_graph))
removed_nodes = 0
for sg in subgraphs:
object_names = []
Expand All @@ -150,18 +155,16 @@ def _clean_graph(self):
edges = sg.edges(n, data=True)
for e in edges:
edges_duplicates.append(e)
weights_duplicates.append(e[2]['weight'])
weights_duplicates.append(e[2]["weight"])

# remove extra edges
n_edges_to_remove = len(nodes_duplicate) - 1
remove_idxs = np.argsort(weights_duplicates)[:n_edges_to_remove]
edges_to_remove = np.array(edges_duplicates, dtype=object)[remove_idxs]

for edge_to_remove in edges_to_remove:
clean_graph.remove_edge(
edge_to_remove[0], edge_to_remove[1])
sg.remove_edge(
edge_to_remove[0], edge_to_remove[1])
clean_graph.remove_edge(edge_to_remove[0], edge_to_remove[1])
sg.remove_edge(edge_to_remove[0], edge_to_remove[1])
if self._verbose:
print(f"Removed edge: {edge_to_remove}")

Expand All @@ -177,24 +180,24 @@ def _clean_graph(self):
removed_nodes += 1

if self._verbose:
print(f'Removed {removed_nodes} duplicate nodes')
print(f"Removed {removed_nodes} duplicate nodes")
self.clean_graph = clean_graph

def _do_agreement(self):
# extract agreement from graph
if self._verbose:
print('Multicomparison step 4: extract agreement from graph')
print("Multicomparison step 4: extract agreement from graph")

self._new_units = {}

# save new units
import networkx as nx
self.subgraphs = [self.clean_graph.subgraph(c).copy()
for c in nx.connected_components(self.clean_graph)]

self.subgraphs = [self.clean_graph.subgraph(c).copy() for c in nx.connected_components(self.clean_graph)]
for new_unit, sg in enumerate(self.subgraphs):
edges = list(sg.edges(data=True))
if len(edges) > 0:
avg_agr = np.mean([d['weight'] for u, v, d in edges])
avg_agr = np.mean([d["weight"] for u, v, d in edges])
else:
avg_agr = 0
object_unit_ids = {}
Expand All @@ -206,32 +209,37 @@ def _do_agreement(self):
for name in self.name_list:
if name in object_unit_ids:
sorted_object_unit_ids[name] = object_unit_ids[name]
self._new_units[new_unit] = {'avg_agreement': avg_agr, 'unit_ids': sorted_object_unit_ids,
'agreement_number': len(sg.nodes)}
self._new_units[new_unit] = {
"avg_agreement": avg_agr,
"unit_ids": sorted_object_unit_ids,
"agreement_number": len(sg.nodes),
}


class BasePairComparison(BaseComparison):
"""
Base class for pair comparisons.
It handles the matching procedurs.
Agreement scores must be computed in inherited classes by overriding the
Agreement scores must be computed in inherited classes by overriding the
'_do_agreement(self)' function
"""
def __init__(self, object1, object2, name1, name2,
match_score=0.5, chance_score=0.1,
verbose=False):
BaseComparison.__init__(self, object_list=[object1, object2],
name_list=[name1, name2],
match_score=match_score,
chance_score=chance_score,
verbose=verbose)

def __init__(self, object1, object2, name1, name2, match_score=0.5, chance_score=0.1, verbose=False):
BaseComparison.__init__(
self,
object_list=[object1, object2],
name_list=[name1, name2],
match_score=match_score,
chance_score=chance_score,
verbose=verbose,
)
self.possible_match_12, self.possible_match_21 = None, None
self.best_match_12, self.best_match_21 = None, None
self.hungarian_match_12, self.hungarian_match_21 = None, None
self.agreement_scores = None

def _do_agreement(self):
# populate self.agreement_scores
raise NotImplementedError
Expand All @@ -240,13 +248,10 @@ def _do_matching(self):
if self._verbose:
print("Matching...")

self.possible_match_12, self.possible_match_21 = make_possible_match(
self.agreement_scores, self.chance_score)
self.best_match_12, self.best_match_21 = make_best_match(
self.agreement_scores, self.chance_score)
self.hungarian_match_12, self.hungarian_match_21 = make_hungarian_match(
self.agreement_scores, self.match_score)

self.possible_match_12, self.possible_match_21 = make_possible_match(self.agreement_scores, self.chance_score)
self.best_match_12, self.best_match_21 = make_best_match(self.agreement_scores, self.chance_score)
self.hungarian_match_12, self.hungarian_match_21 = make_hungarian_match(self.agreement_scores, self.match_score)

def get_ordered_agreement_scores(self):
assert self.agreement_scores is not None, "'agreement_scores' have not been computed!"
# order rows
Expand Down Expand Up @@ -275,34 +280,31 @@ class MixinSpikeTrainComparison:
* sampling frequency
* n_jobs
"""

def __init__(self, delta_time=0.4, n_jobs=-1):
self.delta_time = delta_time
self.n_jobs = n_jobs
self.sampling_frequency = None
self.delta_frames = None

def set_frames_and_frequency(self, sorting_list):
sorting0 = sorting_list[0]
# check num segments
if not np.all(sorting.get_num_segments() == sorting0.get_num_segments()
for sorting in sorting_list):
raise Exception('Sorting objects must have the same number of segments.')
if not np.all(sorting.get_num_segments() == sorting0.get_num_segments() for sorting in sorting_list):
raise Exception("Sorting objects must have the same number of segments.")

# take sampling frequency from sorting list and test that they are equivalent.
sampling_freqs = np.array([sorting.get_sampling_frequency()
for sorting in sorting_list], dtype='float64')
sampling_freqs = np.array([sorting.get_sampling_frequency() for sorting in sorting_list], dtype="float64")

# Some sorter round the sampling freq lets emit a warning
sf0 = sampling_freqs[0]
if not np.all(sf0 == sampling_freqs):
delta_freq_ratio = np.abs(sampling_freqs - sf0) / sf0
# tolerance of 0.1%
assert np.all(
delta_freq_ratio < 0.001), "Inconsistent sampling frequency among sorting list"
assert np.all(delta_freq_ratio < 0.001), "Inconsistent sampling frequency among sorting list"

self.sampling_frequency = sf0
self.delta_frames = int(
self.delta_time / 1000 * self.sampling_frequency)
self.delta_frames = int(self.delta_time / 1000 * self.sampling_frequency)


class MixinTemplateComparison:
Expand All @@ -311,6 +313,7 @@ class MixinTemplateComparison:
* similarity method
* sparsity
"""

def __init__(self, similarity_method="cosine_similarity", sparsity_dict=None):
self.similarity_method = similarity_method
self.sparsity_dict = sparsity_dict
Loading

0 comments on commit 7c6d356

Please sign in to comment.