diff --git a/spiketoolkit/postprocessing/postprocessing_tools.py b/spiketoolkit/postprocessing/postprocessing_tools.py
index fec41ce1..2c8422b4 100644
--- a/spiketoolkit/postprocessing/postprocessing_tools.py
+++ b/spiketoolkit/postprocessing/postprocessing_tools.py
@@ -124,17 +124,16 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, retu
         for unit_id in unit_ids:
             waveforms = sorting.get_unit_spike_features(unit_id, 'waveforms')
             waveform_list.append(waveforms)
-            if return_idxs:
-                if len(waveforms) < len(sorting.get_unit_spike_train(unit_id)):
-                    indexes = sorting.get_unit_spike_features(unit_id, 'waveforms_idxs')
-                else:
-                    indexes = np.arange(len(waveforms))
-                if 'waveforms_channel_idxs' in sorting.get_shared_unit_property_names():
-                    channel_idxs = sorting.get_unit_property(unit_id, 'waveforms_channel_idxs')
-                else:
-                    channel_idxs = np.arange(recording.get_num_channels())
-                spike_index_list.append(indexes)
-                channel_index_list.append(channel_idxs)
+            if len(waveforms) < len(sorting.get_unit_spike_train(unit_id)):
+                indexes = sorting.get_unit_spike_features(unit_id, 'waveforms_idxs')
+            else:
+                indexes = np.arange(len(waveforms))
+            if 'waveforms_channel_idxs' in sorting.get_shared_unit_property_names():
+                channel_idxs = sorting.get_unit_property(unit_id, 'waveforms_channel_idxs')
+            else:
+                channel_idxs = np.arange(recording.get_num_channels())
+            spike_index_list.append(indexes)
+            channel_index_list.append(channel_idxs)
     else:
         if dtype is None:
             dtype = recording.get_dtype()
@@ -647,9 +646,9 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret
     recompute_info: bool
         If True, waveforms are recomputed (default False)
     n_jobs: int
-        Number of jobs for parallelization. Default is None (no parallelization).
+        Number of jobs for parallelization. Default is None (no parallelization)
     joblib_backend: str
-        The backend for joblib. Default is 'loky'.
+        The backend for joblib. Default is 'loky'
     verbose: bool
         If True output is verbose
@@ -689,12 +688,11 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret
         for unit_id in unit_ids:
             amplitudes = sorting.get_unit_spike_features(unit_id, 'amplitudes')
             amp_list.append(amplitudes)
-            if return_idxs:
-                if len(amplitudes) < len(sorting.get_unit_spike_train(unit_id)):
-                    indexes = sorting.get_unit_spike_features(unit_id, 'amplitudes_idxs')
-                else:
-                    indexes = np.arange(len(amplitudes))
-                spike_index_list.append(indexes)
+            if len(amplitudes) < len(sorting.get_unit_spike_train(unit_id)):
+                indexes = sorting.get_unit_spike_features(unit_id, 'amplitudes_idxs')
+            else:
+                indexes = np.arange(len(amplitudes))
+            spike_index_list.append(indexes)
     else:
         # pre-construct memmap arrays
         if memmap:
@@ -740,9 +738,9 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret
             else:
                 amp_list[i] = amps
 
-        if save_property_or_features:
-            for i, unit_id in enumerate(unit_ids):
-                sorting.set_unit_spike_features(unit_id, 'amplitudes', amp_list[i], indexes=spike_index_list[i])
+    if save_property_or_features:
+        for i, unit_id in enumerate(unit_ids):
+            sorting.set_unit_spike_features(unit_id, 'amplitudes', amp_list[i], indexes=spike_index_list[i])
 
     if return_idxs:
         return amp_list, spike_index_list
@@ -809,44 +807,12 @@ def compute_channel_spiking_activity(recording, channel_ids=None, detect_thresho
         start_frame = 0
     if end_frame is None:
         end_frame = recording.get_num_frames()
-    duration = (end_frame - start_frame) / recording.get_sampling_frequency()
 
     assert np.all([ch in recording.get_channel_ids() for ch in channel_ids]), "Invalid channel_ids"
     spike_rates = np.zeros(len(channel_ids))
     spike_amplitudes = np.zeros(len(channel_ids))
 
-    if n_jobs is None:
-        n_jobs = 1
-    if n_jobs == 0:
-        n_jobs = 1
-
-    if start_frame != 0 or end_frame != recording.get_num_frames():
-        recording_sub = se.SubRecordingExtractor(recording, start_frame=start_frame, end_frame=end_frame)
-    else:
-        recording_sub = recording
-
-    num_frames = recording_sub.get_num_frames()
-
-    # set chunk size
-    if chunk_size is not None:
-        chunk_size = int(chunk_size)
-    elif chunk_mb is not None:
-        n_bytes = np.dtype(recording.get_dtype()).itemsize
-        max_size = int(chunk_mb * 1e6)  # set Mb per chunk
-        chunk_size = max_size // (recording.get_num_channels() * n_bytes)
-
-    if n_jobs > 1:
-        chunk_size /= n_jobs
-
-    # chunk_size = num_bytes_per_chunk / num_bytes_per_frame
-    chunks = divide_recording_into_time_chunks(
-        num_frames=num_frames,
-        chunk_size=chunk_size,
-        padding_size=0
-    )
-    n_chunk = len(chunks)
-
     if 'spike_rate' in recording.get_shared_channel_property_names() and \
             'spike_amplitude' in recording.get_shared_channel_property_names() and not recompute_info:
         for i, ch in enumerate(recording.get_channel_ids()):
@@ -855,7 +821,8 @@
     else:
         sort_detect = st.sortingcomponents.detect_spikes(recording, channel_ids=channel_ids,
                                                          detect_threshold=detect_threshold, detect_sign=detect_sign,
-                                                         n_jobs=n_jobs, start_frame=start_frame, end_frame=end_frame,
+                                                         n_jobs=n_jobs, joblib_backend=joblib_backend,
+                                                         start_frame=start_frame, end_frame=end_frame,
                                                          verbose=verbose)
 
         for i, unit in enumerate(sort_detect.get_unit_ids()):
@@ -1058,17 +1025,16 @@ def compute_unit_pca_scores(recording, sorting, unit_ids=None, channel_ids=None,
         for unit_id in unit_ids:
             pca_scores = sorting.get_unit_spike_features(unit_id, 'pca_scores')
             pca_scores_list.append(pca_scores)
-            if return_idxs:
-                if len(pca_scores) < len(sorting.get_unit_spike_train(unit_id)):
-                    indexes = sorting.get_unit_spike_features(unit_id, 'pca_scores_idxs')
-                else:
-                    indexes = np.arange(len(pca_scores))
-                if 'pca_scores_channel_idxs' in sorting.get_shared_unit_property_names() and by_electrode:
-                    channel_idxs = sorting.get_unit_property(unit_id, 'pca_scores_channel_idxs')
-                else:
-                    channel_idxs = np.arange(recording.get_num_channels())
-                spike_index_list.append(indexes)
-                channel_index_list.append(channel_idxs)
+            if len(pca_scores) < len(sorting.get_unit_spike_train(unit_id)):
+                indexes = sorting.get_unit_spike_features(unit_id, 'pca_scores_idxs')
+            else:
+                indexes = np.arange(len(pca_scores))
+            if 'pca_scores_channel_idxs' in sorting.get_shared_unit_property_names() and by_electrode:
+                channel_idxs = sorting.get_unit_property(unit_id, 'pca_scores_channel_idxs')
+            else:
+                channel_idxs = np.arange(recording.get_num_channels())
+            spike_index_list.append(indexes)
+            channel_index_list.append(channel_idxs)
     else:
         if max_spikes_for_pca is None:
             max_spikes_for_pca = np.inf
diff --git a/spiketoolkit/sortingcomponents/detection.py b/spiketoolkit/sortingcomponents/detection.py
index 2010b907..62c0e4ba 100644
--- a/spiketoolkit/sortingcomponents/detection.py
+++ b/spiketoolkit/sortingcomponents/detection.py
@@ -8,7 +8,7 @@
 
 def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-1,
                   n_shifts=2, n_snippets_for_threshold=10, snippet_size_sec=1,
-                  start_frame=None, end_frame=None, n_jobs=1,
+                  start_frame=None, end_frame=None, n_jobs=1, joblib_backend='loky',
                   chunk_size=None, chunk_mb=500, verbose=False):
     '''
     Detects spikes per channel. Spikes are detected as threshold crossings and the threshold is in terms of the median
@@ -38,13 +38,15 @@ def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-
     end_frame: int
         End frame end frame for detection
     n_jobs: int
-        Number of jobs when parallel
+        Number of jobs for parallelization. Default is 1 (no parallelization)
+    joblib_backend: str
+        The backend for joblib. Default is 'loky'
     chunk_size: int
         Size of chunks in number of samples. If None, it is automatically calculated
     chunk_mb: int
         Size of chunks in Mb (default 500 Mb)
     verbose: bool
-        If True output is verbose
+        If True output is verbose
 
     Returns
     -------
@@ -124,10 +126,11 @@ def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-
     thresholds = detect_threshold * np.median(np.abs(traces_mad) / 0.6745, 1)[:, None]
 
     if n_jobs > 1:
-        output = Parallel(n_jobs=n_jobs)(delayed(_detect_and_align_peaks_chunk)
-                                         (ii, rec_arg, chunks, channel_ids, thresholds, detect_sign,
-                                          n_shifts, verbose)
-                                         for ii in chunk_iter)
+        output = Parallel(n_jobs=n_jobs, backend=joblib_backend)(delayed(_detect_and_align_peaks_chunk)
+                                                                 (ii, rec_arg, chunks, channel_ids, thresholds,
+                                                                  detect_sign,
+                                                                  n_shifts, verbose)
+                                                                 for ii in chunk_iter)
         for ii, (times_ii, amps_ii) in enumerate(output):
             for i, ch in enumerate(channel_ids):
                 times = times_ii[i]
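Usage sketch (illustration only, not part of the diff): a minimal call of the changed entry points. The toy_example helper and the max_spikes_per_unit argument are assumed from spikeextractors/spiketoolkit; everything else follows the signatures shown in the hunks above.

    import spikeextractors as se
    import spiketoolkit as st

    # small synthetic recording/sorting pair, for illustration only
    recording, sorting = se.example_datasets.toy_example()

    # threshold-crossing detection, parallelized over chunks with the chosen joblib backend
    detected = st.sortingcomponents.detect_spikes(recording, detect_threshold=5, detect_sign=-1,
                                                  n_jobs=2, joblib_backend='loky')
    print(detected.get_unit_ids())

    # amplitudes plus the subsampled-spike indexes; with the change above the index list
    # is populated even when amplitudes are loaded from cached spike features
    amps, amp_idxs = st.postprocessing.get_unit_amplitudes(recording, sorting,
                                                           max_spikes_per_unit=100, return_idxs=True)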