Skip to content

Commit

Permalink
more fix after merge with main and the new pickle to file mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Sep 27, 2023
1 parent c0c2163 commit 6c561f2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
21 changes: 12 additions & 9 deletions src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True
sorter_name = params.pop("sorter_name")
job = dict(sorter_name=sorter_name,
recording=recording,
output_folder=sorter_folder)
output_folder=sorter_folder,
)
job.update(params)
# the verbose is overwritten and global to all run_sorters
job["verbose"] = verbose
job["with_output"] = False
job_list.append(job)

run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False)
Expand All @@ -217,7 +219,8 @@ def copy_sortings(self, case_keys=None, force=True):


if (sorter_folder / "spikeinterface_log.json").exists():
sorting = read_sorter_folder(sorter_folder, raise_error=False)
sorting = read_sorter_folder(sorter_folder, raise_error=False,
register_recording=False, sorting_info=False)
else:
sorting = None

Expand Down Expand Up @@ -383,13 +386,12 @@ def get_count_units(
index = pd.MultiIndex.from_tuples(case_keys, names=self.levels)


columns = ["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"]
columns = ["num_gt", "num_sorter", "num_well_detected"]
comp = self.comparisons[case_keys[0]]
if comp.exhaustive_gt:
columns.extend(["num_false_positive", "num_bad"])
columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"])
count_units = pd.DataFrame(index=index, columns=columns, dtype=int)


for key in case_keys:
comp = self.comparisons.get(key, None)
assert comp is not None, "You need to do study.run_comparisons() first"
Expand All @@ -402,11 +404,12 @@ def get_count_units(
count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(
well_detected_score
)
count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(
overmerged_score
)
count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)

if comp.exhaustive_gt:
count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(
overmerged_score
)
count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(
redundant_score
)
Expand Down
10 changes: 7 additions & 3 deletions src/spikeinterface/sorters/basesorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False):
recording = None
else:
recording = load_extractor(json_file, base_folder=output_folder)
elif pickle_file.exits():
elif pickle_file.exists():
recording = load_extractor(pickle_file)

return recording
Expand Down Expand Up @@ -324,8 +324,12 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_

if sorting_info:
# set sorting info to Sorting object
with open(output_folder / "spikeinterface_recording.json", "r") as f:
rec_dict = json.load(f)
if (output_folder / "spikeinterface_recording.json").exists():
with open(output_folder / "spikeinterface_recording.json", "r") as f:
rec_dict = json.load(f)
else:
rec_dict = None

with open(output_folder / "spikeinterface_params.json", "r") as f:
params_dict = json.load(f)
with open(output_folder / "spikeinterface_log.json", "r") as f:
Expand Down
8 changes: 7 additions & 1 deletion src/spikeinterface/sorters/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
engine_kwargs: dict
return_output: bool, dfault False
Return a sorting or None.
Return a list of sortings or None.
This also overwrites the kwarg in run_sorter (with_output=True/False)
Returns
-------
Expand All @@ -88,8 +89,13 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
"processpoolexecutor",
), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True."
out = []
for kwargs in job_list:
kwargs['with_output'] = True
else:
out = None
for kwargs in job_list:
kwargs['with_output'] = False


if engine == "loop":
# simple loop in main process
Expand Down

0 comments on commit 6c561f2

Please sign in to comment.