Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 5, 2024
1 parent 91dfd38 commit ecb7e3e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


# This is to separate names when the key are tuples when saving folders
# _key_separator = "_##_"
# _key_separator = "_##_"
_key_separator = "_-°°-_"


Expand Down
32 changes: 13 additions & 19 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"threshold_diff": 1.5,
},
"templates": {
"ms_before": 2.,
"ms_after": 3.,
"max_spikes_per_unit" : 400,
"ms_before": 2.0,
"ms_after": 3.0,
"max_spikes_per_unit": 400,
# "peak_shift_ms": 0.2,
},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
"matching": {"method": "circus-omp-svd", "method_kwargs": {}},

"job_kwargs": {"n_jobs": -1},
"save_array": True,
}
Expand All @@ -63,16 +62,16 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"apply_preprocessing": "Apply internal preprocessing or not",
"cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ",
"waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um",
"filtering": "A dictonary containing filtering params: freq_min, freq_max",
"detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um",
"selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks",
"svd": "A dictonary containing svd params: n_components",
"filtering": "A dictonary containing filtering params: freq_min, freq_max",
"detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um",
"selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks",
"svd": "A dictonary containing svd params: n_components",
"clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um",
"templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after",
"matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um",
"job_kwargs": "A dictionary containing job kwargs",
"save_array": "Save or not intermediate arrays",
}
}

handle_multi_segment = True

Expand Down Expand Up @@ -118,10 +117,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
recording = common_reference(recording)
recording = zscore(recording, dtype="float32")
# recording = whiten(recording, dtype="float32")

# used only if "folder" or "zarr"
cache_folder = sorter_output_folder / "cache_preprocessing"
recording = cache_preprocessing(recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"])
recording = cache_preprocessing(
recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"]
)

noise_levels = np.ones(num_chans, dtype="float32")
else:
Expand Down Expand Up @@ -243,7 +244,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

merge_radius_um = params["clustering"]["merge_radius_um"]
threshold_diff = params["clustering"]["threshold_diff"]


post_merge_label, peak_shifts = merge_clusters(
peaks,
Expand Down Expand Up @@ -297,17 +297,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)
sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp")

we = extract_waveforms(
recording,
sorting_temp,
sorter_output_folder / "waveforms_temp",
**params["templates"])
we = extract_waveforms(recording, sorting_temp, sorter_output_folder / "waveforms_temp", **params["templates"])

# snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
# print(snrs)



# matching_params = params["matching"].copy()
# matching_params["waveform_extractor"] = we
# matching_params["noise_levels"] = noise_levels
Expand Down
21 changes: 7 additions & 14 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,10 +660,10 @@ def merge(
for shift in range(-num_shift, num_shift + 1):
temp0 = template0[num_shift : num_samples - num_shift, :]
temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
#d = np.mean(np.abs(temp0 - temp1)) / (norm)
# d = np.mean(np.abs(temp0 - temp1)) / (norm)
# d = np.max(np.abs(temp0 - temp1)) / (norm)
diff_per_channel = np.abs(temp0 - temp1) / norm

diff_max = np.max(diff_per_channel, axis=0)

# diff = np.max(diff_per_channel)
Expand All @@ -675,7 +675,6 @@ def merge(
# d = np.mean(diff_by_channel)
# all_shift_diff.append(d)
normed_diff = np.min(all_shift_diff)


is_merge = normed_diff < threshold_diff

Expand All @@ -687,7 +686,6 @@ def merge(
else:
final_shift = 0
merge_value = np.nan


# print('merge_value', merge_value, 'final_shift', final_shift, 'is_merge', is_merge)

Expand All @@ -696,7 +694,7 @@ def merge(
# if DEBUG and ( 0. < normed_diff < .4):
# if 0.5 < normed_diff < 4:
if DEBUG and is_merge:
# if DEBUG:
# if DEBUG:

import matplotlib.pyplot as plt

Expand All @@ -711,8 +709,6 @@ def merge(
m0 = temp0.T.flatten()
m1 = temp1.T.flatten()



ax = axs[0]
ax.plot(m0, color="C0", label=f"{label0} {inds0.size}")
ax.plot(m1, color="C1", label=f"{label1} {inds1.size}")
Expand All @@ -724,16 +720,15 @@ def merge(

ax = axs[1]

#~ temp0 = template0[num_shift : num_samples - num_shift, :]
#~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
# ~ temp0 = template0[num_shift : num_samples - num_shift, :]
# ~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
ax.plot(np.abs(m0 - m1))
# ax.axhline(norm, ls='--', color='k')
ax = axs[2]
ax.plot(diff_per_channel.T.flatten())
ax.axhline(threshold_diff, ls='--')
ax.axhline(threshold_diff, ls="--")
ax.axhline(normed_diff)



# ax.axhline(normed_diff, ls='-', color='b')
# ax.plot(norm, ls='--')
# ax.plot(diff_by_channel)
Expand All @@ -744,8 +739,6 @@ def merge(

# ax.set_title(f"{norm=:.3f}")



plt.show()

return is_merge, label0, label1, final_shift, merge_value
Expand Down

0 comments on commit ecb7e3e

Please sign in to comment.