Merge branch 'SpikeInterface:main' into meta_merging
yger authored Nov 4, 2024
2 parents 33ae9eb + c4d2eaa commit 2f80b8d
Showing 6 changed files with 45 additions and 20 deletions.
22 changes: 11 additions & 11 deletions src/spikeinterface/core/job_tools.py
@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):
 
 
 def divide_recording_into_chunks(recording, chunk_size):
-    all_chunks = []
+    recording_slices = []
     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         chunks = divide_segment_into_chunks(num_frames, chunk_size)
-        all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
-    return all_chunks
+        recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
+    return recording_slices
 
 
 def ensure_n_jobs(recording, n_jobs=1):
@@ -387,13 +387,13 @@ def __init__(
             f"chunk_duration={chunk_duration_str}",
         )
 
-    def run(self, all_chunks=None):
+    def run(self, recording_slices=None):
         """
         Runs the defined jobs.
         """
 
-        if all_chunks is None:
-            all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
+        if recording_slices is None:
+            recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)
 
         if self.handle_returns:
             returns = []
@@ -402,17 +402,17 @@ def run(self, all_chunks=None):
 
         if self.n_jobs == 1:
             if self.progress_bar:
-                all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name)
+                recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)
 
             worker_ctx = self.init_func(*self.init_args)
-            for segment_index, frame_start, frame_stop in all_chunks:
+            for segment_index, frame_start, frame_stop in recording_slices:
                 res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
                 if self.handle_returns:
                     returns.append(res)
                 if self.gather_func is not None:
                     self.gather_func(res)
         else:
-            n_jobs = min(self.n_jobs, len(all_chunks))
+            n_jobs = min(self.n_jobs, len(recording_slices))
 
             # parallel
             with ProcessPoolExecutor(
@@ -421,10 +421,10 @@ def run(self, all_chunks=None):
                 mp_context=mp.get_context(self.mp_context),
                 initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
             ) as executor:
-                results = executor.map(function_wrapper, all_chunks)
+                results = executor.map(function_wrapper, recording_slices)
 
                 if self.progress_bar:
-                    results = tqdm(results, desc=self.job_name, total=len(all_chunks))
+                    results = tqdm(results, desc=self.job_name, total=len(recording_slices))
 
                 for res in results:
                     if self.handle_returns:
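As a quick illustration of the renamed argument, a minimal sketch (the recording is generated here only for the example, and 10_000 is an arbitrary chunk size in samples; both imports appear verbatim in the test file changed below):

    from spikeinterface import generate_ground_truth_recording
    from spikeinterface.core.job_tools import divide_recording_into_chunks

    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
    recording_slices = divide_recording_into_chunks(recording, 10_000)
    # each slice is a (segment_index, frame_start, frame_stop) tuple
    recording_slices = recording_slices[::4]  # keep one chunk out of four
    # the subset can then be passed to ChunkRecordingExecutor.run(recording_slices=recording_slices)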
7 changes: 6 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -489,6 +489,7 @@ def run_node_pipeline(
     names=None,
     verbose=False,
     skip_after_n_peaks=None,
+    recording_slices=None,
 ):
     """
     Machinery to compute in parallel operations on peaks and traces.
@@ -540,6 +541,10 @@ def run_node_pipeline(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
        This is not exact because internally the skip is applied per worker, on average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.
 
     Returns
     -------
@@ -578,7 +583,7 @@ def run_node_pipeline(
         **job_kwargs,
     )
 
-    processor.run()
+    processor.run(recording_slices=recording_slices)
 
     outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
     return outs
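For reference, a minimal sketch of the new parameter in use, mirroring the test further down (recording, nodes and job_kwargs are assumed to be defined as in that test; run_node_pipeline is imported from spikeinterface.core.node_pipeline):

    recording_slices = divide_recording_into_chunks(recording, 10_000)[::4]
    some_amplitudes = run_node_pipeline(
        recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices
    )
    # only the selected chunks are processed; with recording_slices=None the whole recording is iterated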
2 changes: 1 addition & 1 deletion src/spikeinterface/core/recording_tools.py
@@ -806,7 +806,7 @@ def append_noise_chunk(res):
         gather_func=append_noise_chunk,
         **job_kwargs,
     )
-    executor.run(all_chunks=recording_slices)
+    executor.run(recording_slices=recording_slices)
     noise_levels_chunks = np.stack(noise_levels_chunks)
     noise_levels = np.mean(noise_levels_chunks, axis=0)
 
19 changes: 14 additions & 5 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -4,7 +4,7 @@
 import shutil
 
 from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
-
+from spikeinterface.core.job_tools import divide_recording_into_chunks
 
 # from spikeinterface.sortingcomponents.peak_detection import detect_peaks
 from spikeinterface.core.node_pipeline import (
@@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation):
     unpickled_node = pickle.loads(pickled_node)
 
 
-def test_skip_after_n_peaks():
-    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
+def test_skip_after_n_peaks_and_recording_slices():
+    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205)
 
     # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
     job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
@@ -211,18 +211,27 @@ def test_skip_after_n_peaks():
     node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
     nodes = [node0, node1]
 
+    # skip
     skip_after_n_peaks = 30
     some_amplitudes = run_node_pipeline(
         recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks
     )
 
     assert some_amplitudes.size >= skip_after_n_peaks
     assert some_amplitudes.size < spikes.size
 
+    # slices : 1 every 4
+    recording_slices = divide_recording_into_chunks(recording, 10_000)
+    recording_slices = recording_slices[::4]
+    some_amplitudes = run_node_pipeline(
+        recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices
+    )
+    tolerance = 1.2
+    assert some_amplitudes.size < (spikes.size // 4) * tolerance
 
 
 # the following is for testing locally with python or ipython. It is not used in ci or with pytest.
 if __name__ == "__main__":
     # folder = Path("./cache_folder/core")
     # test_run_node_pipeline(folder)
 
-    test_skip_after_n_peaks()
+    test_skip_after_n_peaks_and_recording_slices()
9 changes: 7 additions & 2 deletions src/spikeinterface/preprocessing/whiten.py
@@ -19,6 +19,8 @@ class WhitenRecording(BasePreprocessor):
     recording : RecordingExtractor
         The recording extractor to be whitened.
     dtype : None or dtype, default: None
+        Datatype of the output recording (covariance matrix estimation
+        and whitening are performed in float32).
         If None, the parent dtype is kept.
         For integer dtype, an int_scale must also be given.
     mode : "global" | "local", default: "global"
@@ -74,7 +76,9 @@ def __init__(
         dtype_ = fix_dtype(recording, dtype)
 
         if dtype_.kind == "i":
-            assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale"
+            assert (
+                int_scale is not None
+            ), "For recording with dtype=int you must set the output dtype to float OR set an int_scale"
 
         if W is not None:
             W = np.asarray(W)
@@ -124,7 +128,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale):
     def get_traces(self, start_frame, end_frame, channel_indices):
         traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))
         traces_dtype = traces.dtype
-        # if uint --> force int
+        # if uint --> force float
         if traces_dtype.kind == "u":
             traces = traces.astype("float32")
 
@@ -185,6 +189,7 @@ def compute_whitening_matrix(
     """
     random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs)
    random_data = random_data.astype(np.float32)
+
     regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"}
 
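A hedged usage sketch of the constraint enforced by the assert above (assuming WhitenRecording is used through spikeinterface.preprocessing.whiten, and that an existing recording is available; the int_scale value of 200 is arbitrary):

    from spikeinterface.preprocessing import whiten

    # keeping an integer output dtype requires an explicit int_scale
    rec_w_int = whiten(recording, dtype="int16", int_scale=200)

    # with a float output dtype no int_scale is needed
    rec_w_float = whiten(recording, dtype="float32")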
6 changes: 6 additions & 0 deletions src/spikeinterface/sortingcomponents/peak_detection.py
@@ -57,6 +57,7 @@ def detect_peaks(
     folder=None,
     names=None,
     skip_after_n_peaks=None,
+    recording_slices=None,
     **kwargs,
 ):
     """Peak detection based on threshold crossing in terms of k x MAD.
@@ -83,6 +84,10 @@ def detect_peaks(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
         This is not exact because internally the skip is applied per worker, on average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.
     {method_doc}
     {job_doc}
@@ -135,6 +140,7 @@ def detect_peaks(
         folder=folder,
         names=names,
         skip_after_n_peaks=skip_after_n_peaks,
+        recording_slices=recording_slices,
     )
     return outs
 
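A sketch of how the forwarded argument can be used from detect_peaks (recording_slices built with divide_recording_into_chunks as in the test above; the method name and job kwargs here are illustrative, not prescribed by this commit):

    from spikeinterface.sortingcomponents.peak_detection import detect_peaks

    peaks = detect_peaks(
        recording,
        method="by_channel",
        recording_slices=recording_slices,  # peaks are searched only in these chunks
        chunk_duration="0.5s",
        n_jobs=1,
    )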
