Add job_name to job_kwargs #1986

Closed
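
In short: job_name stops being an explicit keyword argument of ChunkRecordingExecutor / run_node_pipeline and instead travels inside job_kwargs, next to n_jobs, chunk_duration, etc. A minimal before/after sketch of one call site (write_binary_recording), assuming recording, func, init_func, init_args and job_kwargs are already set up as in the diffs below:

# before this PR: job_name passed separately from job_kwargs
executor = ChunkRecordingExecutor(
    recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs
)
executor.run()

# after this PR: job_name is just another job_kwargs entry
job_kwargs["job_name"] = "write_binary_recording"
executor = ChunkRecordingExecutor(recording, func, init_func, init_args, **job_kwargs)
executor.run()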
15 changes: 6 additions & 9 deletions src/spikeinterface/core/core_tools.py
@@ -308,9 +308,8 @@ def write_binary_recording(
func = _write_binary_chunk
init_func = _init_binary_worker
init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned)
executor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs
)
job_kwargs["job_name"] = "write_binary_recording"
executor = ChunkRecordingExecutor(recording, func, init_func, init_args, **job_kwargs)
executor.run()


@@ -477,9 +476,8 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
else:
init_args = (recording, arrays, None, None, dtype, cast_unsigned)

executor = ChunkRecordingExecutor(
recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs
)
job_kwargs["job_name"] = "write_memory_recording"
executor = ChunkRecordingExecutor(recording, func, init_func, init_args, verbose=verbose, **job_kwargs)
executor.run()

return arrays
@@ -705,9 +703,8 @@ def write_traces_to_zarr(
func = _write_zarr_chunk
init_func = _init_zarr_worker
init_args = (recording, zarr_path, storage_options, dataset_paths, dtype, cast_unsigned)
executor = ChunkRecordingExecutor(
recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs
)
job_kwargs["job_name"] = "write_zarr_recording"
executor = ChunkRecordingExecutor(recording, func, init_func, init_args, verbose=verbose, **job_kwargs)
executor.run()


1 change: 1 addition & 0 deletions src/spikeinterface/core/job_tools.py
@@ -46,6 +46,7 @@
"mp_context",
"verbose",
"max_threads_per_process",
"job_name",
)
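
Why this one-line change is needed, as a hedged guess from the surrounding code rather than a verified description of spikeinterface internals: this tuple appears to whitelist the keys a job_kwargs dict may contain, so "job_name" has to be listed here for the dicts built elsewhere in this PR to pass validation. A toy validator illustrating the idea (abbreviated key list, not the real spikeinterface check):

job_keys = ("n_jobs", "chunk_duration", "progress_bar", "mp_context", "verbose", "max_threads_per_process", "job_name")

def check_job_kwargs(job_kwargs):
    # toy check: reject any key that is not in the whitelist
    for k in job_kwargs:
        if k not in job_keys:
            raise ValueError(f"{k} is not a valid job keyword argument")

check_job_kwargs({"n_jobs": 2, "job_name": "write_binary_recording"})  # passes only once "job_name" is whitelisted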


4 changes: 2 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
@@ -341,7 +341,6 @@ def run_node_pipeline(
recording,
nodes,
job_kwargs,
job_name="pipeline",
mp_context=None,
gather_mode="memory",
squeeze_output=True,
@@ -366,13 +365,14 @@

init_args = (recording, nodes)

job_kwargs["job_name"] = "pipeline"
Member

I am not sure we want to modify in place a global job_kwargs dict that we have in many scripts.

Collaborator Author

This does not modify the global job_kwargs here. It is just replicating the behavior that we had before.

Or maybe I am missing something?
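
An illustrative, standalone sketch of the point being debated above (toy functions, not spikeinterface code): assigning job_kwargs["job_name"] = ... inside a function mutates whatever dict the caller passed in, so a shared/global job_kwargs dict is only left untouched if the caller passes a copy or the function copies before writing.

def run_with_mutation(job_kwargs):
    # pattern used in this PR: write job_name into the dict that was passed in
    job_kwargs["job_name"] = "pipeline"

def run_with_copy(job_kwargs):
    # defensive alternative: rebind to a copy, so the caller's dict is untouched
    job_kwargs = dict(job_kwargs, job_name="pipeline")

global_job_kwargs = {"n_jobs": 2, "chunk_duration": "200ms"}

run_with_mutation(dict(global_job_kwargs))  # caller passes a copy: the global dict stays clean
print(global_job_kwargs)                    # {'n_jobs': 2, 'chunk_duration': '200ms'}

run_with_mutation(global_job_kwargs)        # caller passes the dict itself: "job_name" leaks in
print(global_job_kwargs)                    # {'n_jobs': 2, 'chunk_duration': '200ms', 'job_name': 'pipeline'}

global_job_kwargs = {"n_jobs": 2, "chunk_duration": "200ms"}
run_with_copy(global_job_kwargs)            # no side effect on the caller's dict
print(global_job_kwargs)                    # {'n_jobs': 2, 'chunk_duration': '200ms'}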


processor = ChunkRecordingExecutor(
recording,
_compute_peak_pipeline_chunk,
_init_peak_pipeline,
init_args,
gather_func=gather_func,
job_name=job_name,
**job_kwargs,
)

2 changes: 0 additions & 2 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -143,7 +143,6 @@ def __call__(self, res):
gather_func=gathering_func2,
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
)
processor.run()
num_chunks = len(divide_recording_into_chunks(recording, processor.chunk_size))
@@ -161,7 +160,6 @@ def __call__(self, res):
mp_context="spawn",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
)
processor.run()

7 changes: 4 additions & 3 deletions src/spikeinterface/core/waveform_tools.py
@@ -274,8 +274,9 @@ def distribute_waveforms_to_buffers(
sparsity_mask,
)
if job_name is None:
job_name = f"extract waveforms {mode} multi buffer"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
job_kwargs["job_name"] = f"extract waveforms {mode} multi buffer"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, **job_kwargs)
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, **job_kwargs)
processor.run()


@@ -517,7 +518,7 @@ def extract_waveforms_to_single_buffer(
if job_name is None:
job_name = f"extract waveforms {mode} mono buffer"

processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, **job_kwargs)
processor.run()

if mode == "memmap":
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -122,13 +122,13 @@ def _run(self, **job_kwargs):
handle_collisions,
delta_collision_samples,
)
job_kwargs["job_name"] = f"extract amplitude scalings"
processor = ChunkRecordingExecutor(
recording,
func,
init_func,
init_args,
handle_returns=True,
job_name="extract amplitude scalings",
**job_kwargs,
)
out = processor.run()
3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/principal_component.py
@@ -350,7 +350,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
unit_channels,
pca_model,
)
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs)
job_kwargs["job_name"] = "extract PCs"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, **job_kwargs)
processor.run()

def _fit_by_channel_local(self, n_jobs, progress_bar):
5 changes: 2 additions & 3 deletions src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -80,9 +80,8 @@ def _run(self, **job_kwargs):
"`sorting.save()` function to make it dumpable"
)
init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs
)
job_kwargs["job_name"] = "extract amplitudes"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, handle_returns=True, **job_kwargs)
out = processor.run()
amps, segments = zip(*out)
amps = np.concatenate(amps)
3 changes: 2 additions & 1 deletion src/spikeinterface/preprocessing/motion.py
@@ -273,12 +273,13 @@ def correct_motion(
method_class = localize_peak_methods[method]
node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs)
pipeline_nodes = [node0, node1, node2]
job_kwargs["job_name"] = ("detect and localize",)

t0 = time.perf_counter()
peaks, peak_locations = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
job_name="detect and localize",
gather_mode=gather_mode,
squeeze_output=False,
folder=None,
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -66,7 +66,8 @@ def compute_features_from_peaks(
node = Class(recording, parents=[peak_retriever, extract_dense_waveforms], **params)
nodes.append(node)

features = run_node_pipeline(recording, nodes, job_kwargs, job_name="features_from_peaks", squeeze_output=False)
job_kwargs["job_name"] = "features_from_peaks"
features = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=False)

return features

5 changes: 2 additions & 3 deletions src/spikeinterface/sortingcomponents/matching/main.py
@@ -54,9 +54,8 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr
func = _find_spikes_chunk
init_func = _init_worker_find_spikes
init_args = (recording, method, method_kwargs_seralized)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", **job_kwargs
)
job_kwargs["job_name"] = f"find spikes ({method})"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, handle_returns=True, **job_kwargs)
spikes = processor.run()

spikes = np.concatenate(spikes)
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/peak_detection.py
@@ -113,11 +113,11 @@ def detect_peaks(
node.parents = [node0] + node.parents
nodes.append(node)

job_kwargs["job_name"] = job_name
outs = run_node_pipeline(
recording,
nodes,
job_kwargs,
job_name=job_name,
gather_mode=gather_mode,
squeeze_output=squeeze_output,
folder=folder,
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_localization.py
@@ -96,7 +96,8 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_
]

job_name = f"localize peaks using {method}"
peak_locations = run_node_pipeline(recording, pipeline_nodes, job_kwargs, job_name=job_name, squeeze_output=True)
job_kwargs["job_name"] = job_name
peak_locations = run_node_pipeline(recording, pipeline_nodes, job_kwargs, squeeze_output=True)

return peak_locations

3 changes: 1 addition & 2 deletions src/spikeinterface/sortingcomponents/peak_pipeline.py
@@ -8,7 +8,6 @@ def run_peak_pipeline(
peaks,
nodes,
job_kwargs,
job_name="peak_pipeline",
gather_mode="memory",
squeeze_output=True,
folder=None,
@@ -30,11 +29,11 @@
else:
node.parents = [node0] + node.parents
all_nodes = [node0] + nodes
job_kwargs["job_name"] = "peak pipeline"
outs = run_node_pipeline(
recording,
all_nodes,
job_kwargs,
job_name=job_name,
gather_mode=gather_mode,
squeeze_output=squeeze_output,
folder=folder,