From a28c33d5af6f153f1bdfbe2998959ee2139ed250 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Nov 2024 13:12:01 +0100 Subject: [PATCH] for progress_bar the for res in results need to be inside the with --- src/spikeinterface/core/job_tools.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 70a4fe2345..4e0819d0d9 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -428,6 +428,15 @@ def run(self, recording_slices=None): initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), ) as executor: results = executor.map(process_function_wrapper, recording_slices) + + if self.progress_bar: + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) elif self.pool_engine == "thread": # only one shared context @@ -440,19 +449,20 @@ def run(self, recording_slices=None): ) as executor: results = executor.map(thread_func, recording_slices) + if self.progress_bar: + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) + else: raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") - if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(recording_slices)) - - for res in results: - if self.handle_returns: - returns.append(res) - if self.gather_func is not None: - self.gather_func(res)