Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sc2 fix #2058

Merged
merged 3 commits into from
Oct 2, 2023
Merged

Sc2 fix #2058

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True

if sorting_exists:
# delete older sorting + log before running sorters
shutil.rmtree(sorting_exists)
shutil.rmtree(sorting_folder)
log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
if log_file.exists():
log_file.unlink()
Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if clustering_folder.exists():
shutil.rmtree(clustering_folder)

sorting = sorting.save(folder=clustering_folder)

## We get the templates our of such a clustering
waveforms_params = params["waveforms"].copy()
waveforms_params.update(job_kwargs)
Expand All @@ -131,6 +129,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
mode = "memory"
waveforms_folder = None
else:
sorting = sorting.save(folder=clustering_folder)
mode = "folder"
waveforms_folder = sorter_output_folder / "waveforms"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ def remove_duplicates_via_matching(
if tmp_folder is None:
tmp_folder = get_global_tmp_folder()

tmp_folder.mkdir(parents=True, exist_ok=True)

tmp_filename = tmp_folder / "tmp.raw"

f = open(tmp_filename, "wb")
Expand All @@ -583,8 +585,8 @@ def remove_duplicates_via_matching(
f.close()

recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32")
recording.annotate(is_filtered=True)
recording = recording.set_probe(waveform_extractor.recording.get_probe())
recording.annotate(is_filtered=True)

margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter)
half_marging = margin // 2
Expand All @@ -608,7 +610,6 @@ def remove_duplicates_via_matching(
t_stop = padding + (i + 1) * duration

sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging)

method_kwargs.update({"ignored_ids": ignore_ids + [i]})
spikes, computed = find_spikes_from_templates(
sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
Expand Down Expand Up @@ -660,7 +661,7 @@ def remove_duplicates_via_matching(
labels = np.unique(new_labels)
labels = labels[labels >= 0]

del recording, sub_recording
del recording, sub_recording, method_kwargs
os.remove(tmp_filename)

return labels, new_labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,20 @@ def sigmoid(x, L, x0, k, b):
else:
tmp_folder = Path(params["tmp_folder"])

tmp_folder.mkdir(parents=True, exist_ok=True)

sorting_folder = tmp_folder / "sorting"
unit_ids = np.arange(len(np.unique(spikes["unit_index"])))
sorting = NumpySorting(spikes, fs, unit_ids=unit_ids)

if params["shared_memory"]:
waveform_folder = None
mode = "memory"
else:
waveform_folder = tmp_folder / "waveforms"
mode = "folder"
sorting = sorting.save(folder=sorting_folder)

sorting_folder = tmp_folder / "sorting"
unit_ids = np.arange(len(np.unique(spikes["unit_index"])))
sorting = NumpySorting(spikes, fs, unit_ids=unit_ids)
sorting = sorting.save(folder=sorting_folder)
we = extract_waveforms(
recording,
sorting,
Expand Down Expand Up @@ -219,12 +222,14 @@ def sigmoid(x, L, x0, k, b):
we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
)

del we, sorting

if params["tmp_folder"] is None:
shutil.rmtree(tmp_folder)
else:
if not params["shared_memory"]:
shutil.rmtree(tmp_folder / "waveforms")
shutil.rmtree(tmp_folder / "sorting")
shutil.rmtree(tmp_folder / "sorting")

if verbose:
print("We kept %d non-duplicated clusters..." % len(labels))
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/matching/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def _prepare_templates(cls, d):
d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2])
d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0])
d["singular"] = d["singular"].T[:, :, np.newaxis]

return d

@classmethod
Expand Down