Skip to content

Commit

Permalink
for progress_bar the for res in results need to be inside the with
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Nov 8, 2024
1 parent e0ef39b commit a28c33d
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,15 @@ def run(self, recording_slices=None):
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
) as executor:
results = executor.map(process_function_wrapper, recording_slices)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

elif self.pool_engine == "thread":
# only one shared context
Expand All @@ -440,19 +449,20 @@ def run(self, recording_slices=None):
) as executor:
results = executor.map(thread_func, recording_slices)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)


else:
raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'")


if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)



Expand Down

0 comments on commit a28c33d

Please sign in to comment.