Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 24, 2024
1 parent 88d2dfd commit 023baba
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 30 deletions.
31 changes: 20 additions & 11 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy as np



class Tridesclous2Sorter(ComponentsBasedSorter):
sorter_name = "tridesclous2"

Expand Down Expand Up @@ -96,7 +95,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs["progress_bar"] = verbose


recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

num_chans = recording_raw.get_num_channels()
Expand All @@ -111,25 +109,37 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
rec_for_motion = common_reference(rec_for_motion)
if verbose:
print("Start correct_motion()")
_, motion_info = correct_motion(rec_for_motion, folder=sorter_output_folder / "motion", output_motion_info=True,
**params["motion_correction"])
_, motion_info = correct_motion(
rec_for_motion,
folder=sorter_output_folder / "motion",
output_motion_info=True,
**params["motion_correction"],
)
if verbose:
print("Done correct_motion()")

recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32")
recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32")
recording = common_reference(recording)

if params["apply_motion_correction"]:
interpolate_motion_kwargs = dict(
direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
direction=1,
border_mode="force_extrapolate",
spatial_interpolation_method="kriging",
sigma_um=20.0,
p=2,
)

recording = InterpolateMotionRecording(
recording, motion_info["motion"], motion_info["temporal_bins"], motion_info["spatial_bins"], **interpolate_motion_kwargs
)
recording,
motion_info["motion"],
motion_info["temporal_bins"],
motion_info["spatial_bins"],
**interpolate_motion_kwargs,
)

recording = zscore(recording, dtype="float32")
recording = whiten(recording, dtype="float32", mode="local", radius_um=100.)
recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0)

# used only if "folder" or "zarr"
cache_folder = sorter_output_folder / "cache_preprocessing"
Expand Down Expand Up @@ -186,7 +196,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0)
nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0)

templates_array = estimate_templates_with_accumulator(
recording_for_peeler,
sorting_pre_peeler.to_spike_vector(),
Expand Down Expand Up @@ -239,4 +249,3 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting = sorting.save(folder=sorter_output_folder / "sorting")

return sorting

Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)):
ax = axs[0, count]
ax.set_title(self.cases[key]["label"])
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)

return fig

def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
Expand Down Expand Up @@ -247,7 +247,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
fig.colorbar(im, ax=axs[0, count])
label = self.cases[key]["label"]
axs[0, count].set_title(label)

return fig

def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):
Expand Down Expand Up @@ -301,7 +301,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5
label = self.cases[key]["label"]
axs[0, count].set_title(label)
axs[0, count].legend()

return fig

def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):
Expand Down Expand Up @@ -361,7 +361,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
label = self.cases[key]["label"]
axs[0, count].set_title(label)
# axs[0, count].legend()

return fig

def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):
Expand Down Expand Up @@ -492,7 +492,7 @@ def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units
figs.append(fig)
else:
print(key, "no overmerged")

return figs

def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
Expand Down Expand Up @@ -531,4 +531,4 @@ def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units
else:
print(key, "no over splited")

return figs
return figs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def plot_agreements(self, case_keys=None, figsize=None):
ax = axs[0, count]
ax.set_title(self.cases[key]["label"])
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)

return fig

def plot_performances_vs_snr(self, case_keys=None, figsize=None):
Expand All @@ -97,7 +97,7 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=None):

if count == 2:
ax.legend()

return fig

def plot_collisions(self, case_keys=None, figsize=None):
Expand All @@ -116,7 +116,7 @@ def plot_collisions(self, case_keys=None, figsize=None):
mode="lines",
good_only=False,
)

return fig

def plot_comparison_matching(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,13 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
benchmark.save_run(bench_folder)
benchmark.result["run_time"] = float(t1 - t0)
benchmark.save_main(bench_folder)

def set_colors(self, colors=None, map_name="tab20"):
if colors is None:
case_keys = list(self.cases.keys())
self.colors = get_some_colors(case_keys, map_name=map_name,
color_engine = "matplotlib", shuffle=False, margin=0)
self.colors = get_some_colors(
case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0
)
else:
self.colors = colors

Expand Down Expand Up @@ -270,7 +271,7 @@ def plot_run_times(self, case_keys=None):
rt = run_times.at[key, "run_times"]
ax.bar(i, rt, width=0.8, color=colors[key])
ax.set_xticks(np.arange(len(case_keys)))
ax.set_xticklabels(labels, rotation=45.)
ax.set_xticklabels(labels, rotation=45.0)
return fig

# ax = run_times.plot(kind="bar")
Expand Down
10 changes: 5 additions & 5 deletions src/spikeinterface/sortingcomponents/clustering/tdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,19 @@ def main_function(cls, recording, peaks, params):
features_folder,
method="local_feature_clustering",
method_kwargs=dict(

clusterer="hdbscan",
clusterer_kwargs={"min_cluster_size": min_cluster_size, "allow_single_cluster": True, "cluster_selection_method": "eom"},

clusterer_kwargs={
"min_cluster_size": min_cluster_size,
"allow_single_cluster": True,
"cluster_selection_method": "eom",
},
# clusterer="isocut5",
# clusterer_kwargs={"min_cluster_size": min_cluster_size},

feature_name="sparse_tsvd",
# feature_name="sparse_wfs",
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=min_cluster_size,

n_pca_features=3,
scale_n_pca_by_depth=True,
),
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
HAVE_MPL = False


def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None, margin=None):
def get_some_colors(
keys, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None, margin=None
):
"""
Return a dict of colors for given keys
Expand Down

0 comments on commit 023baba

Please sign in to comment.