Commit: clean
samuelgarcia committed May 22, 2024
1 parent 9a03744 · commit ef0bba9
Showing 1 changed file with 6 additions and 39 deletions.
45 changes: 6 additions & 39 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -117,8 +117,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
num_chans = recording_raw.get_num_channels()
sampling_frequency = recording_raw.get_sampling_frequency()

-
-
# preprocessing
if params["apply_preprocessing"]:
if params["apply_motion_correction"]:
@@ -183,25 +181,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_kwargs["waveforms"] = params["waveforms"].copy()
clustering_kwargs["clustering"] = params["clustering"].copy()

- labels_set, post_clean_label, extra_out = find_cluster_from_peaks(
+ labels_set, clustering_label, extra_out = find_cluster_from_peaks(
recording, peaks, method="tdc_clustering", method_kwargs=clustering_kwargs, extra_outputs=True, **job_kwargs
)
peak_shifts = extra_out["peak_shifts"]
new_peaks = peaks.copy()
new_peaks["sample_index"] -= peak_shifts

- mask = post_clean_label >= 0
+ mask = clustering_label >= 0
sorting_pre_peeler = NumpySorting.from_times_labels(
new_peaks["sample_index"][mask],
- post_clean_label[mask],
+ clustering_label[mask],
sampling_frequency,
unit_ids=labels_set,
)

if verbose:
print(f"find_cluster_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found")

- # recording_w = whiten(recording, mode="local", radius_um=100.0)
recording_for_peeler = recording

nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0)
@@ -232,40 +229,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates = templates_dense.to_sparse(sparsity)
templates = remove_empty_templates(templates)
-
- # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
- # print(snrs)
-
- # matching_params = params["matching"].copy()
- # matching_params["noise_levels"] = noise_levels
- # matching_params["peak_sign"] = params["detection"]["peak_sign"]
- # matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
- # matching_params["radius_um"] = params["detection"]["radius_um"]
-
- # spikes = find_spikes_from_templates(
- # recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
- # )
+
## peeler
matching_method = params["matching"]["method"]
matching_params = params["matching"]["method_kwargs"].copy()

matching_params["templates"] = templates
matching_params["noise_levels"] = noise_levels
# matching_params["peak_sign"] = params["detection"]["peak_sign"]
# matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
# matching_params["radius_um"] = params["detection"]["radius_um"]

# spikes = find_spikes_from_templates(
# recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
# )
# )

# if matching_method == "circus-omp-svd":
# job_kwargs = job_kwargs.copy()
# for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
# if value in job_kwargs:
# job_kwargs.pop(value)
# job_kwargs["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
)
@@ -275,9 +243,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
# np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
# np.save(sorter_output_folder / "split_count.npy", split_count)
# np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
np.save(sorter_output_folder / "peaks.npy", peaks)
np.save(sorter_output_folder / "clustering_label.npy", clustering_label)
np.save(sorter_output_folder / "spikes.npy", spikes)

final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
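Below, a minimal sketch (not part of the commit) of how the arrays this sorter saves could be reloaded for a quick sanity check. The output folder path is hypothetical; the file names come from the np.save calls visible in the diff above.

import numpy as np
from pathlib import Path

sorter_output_folder = Path("tdc2_output")  # hypothetical folder; use the sorter's real output folder

peaks = np.load(sorter_output_folder / "peaks.npy")
clustering_label = np.load(sorter_output_folder / "clustering_label.npy")
spikes = np.load(sorter_output_folder / "spikes.npy")

# Same convention as in the diff: negative labels mark peaks discarded by clustering.
mask = clustering_label >= 0
print(f"{mask.sum()} / {peaks.size} peaks kept, "
      f"{np.unique(clustering_label[mask]).size} clusters, "
      f"{spikes.size} spikes after the peeler")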
