Merge pull request #422 from SpikeInterface/small_posptrocessing_fixes
Fixed small bugs and args in postprocessing and detection
alejoe91 authored Dec 9, 2020
2 parents 04ebb6b + 18eb821 commit 97d4ad2
Showing 2 changed files with 42 additions and 73 deletions.
98 changes: 32 additions & 66 deletions spiketoolkit/postprocessing/postprocessing_tools.py
@@ -124,17 +124,16 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, retu
for unit_id in unit_ids:
waveforms = sorting.get_unit_spike_features(unit_id, 'waveforms')
waveform_list.append(waveforms)
if return_idxs:
if len(waveforms) < len(sorting.get_unit_spike_train(unit_id)):
indexes = sorting.get_unit_spike_features(unit_id, 'waveforms_idxs')
else:
indexes = np.arange(len(waveforms))
if 'waveforms_channel_idxs' in sorting.get_shared_unit_property_names():
channel_idxs = sorting.get_unit_property(unit_id, 'waveforms_channel_idxs')
else:
channel_idxs = np.arange(recording.get_num_channels())
spike_index_list.append(indexes)
channel_index_list.append(channel_idxs)
if len(waveforms) < len(sorting.get_unit_spike_train(unit_id)):
indexes = sorting.get_unit_spike_features(unit_id, 'waveforms_idxs')
else:
indexes = np.arange(len(waveforms))
if 'waveforms_channel_idxs' in sorting.get_shared_unit_property_names():
channel_idxs = sorting.get_unit_property(unit_id, 'waveforms_channel_idxs')
else:
channel_idxs = np.arange(recording.get_num_channels())
spike_index_list.append(indexes)
channel_index_list.append(channel_idxs)
else:
if dtype is None:
dtype = recording.get_dtype()
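The hunk above drops the if return_idxs: guard so that the spike and channel index lists are built whenever cached waveform features are reused. A minimal usage sketch of the patched path, assuming the legacy spikeextractors toy_example helper (not part of this diff) for the recording/sorting pair:

import spikeextractors as se
import spiketoolkit as st

# Hypothetical toy dataset; any RecordingExtractor/SortingExtractor pair works.
recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

# First call computes waveforms and caches them as unit spike features.
waveforms = st.postprocessing.get_unit_waveforms(recording, sorting)

# Second call takes the cached-feature branch patched above; with return_idxs=True
# it should now also hand back the spike and channel index lists.
waveforms, spike_idxs, channel_idxs = st.postprocessing.get_unit_waveforms(
    recording, sorting, return_idxs=True)
print(len(waveforms), len(spike_idxs), len(channel_idxs))  # one entry per unit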
@@ -647,9 +646,9 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret
recompute_info: bool
If True, waveforms are recomputed (default False)
n_jobs: int
Number of jobs for parallelization. Default is None (no parallelization).
Number of jobs for parallelization. Default is None (no parallelization)
joblib_backend: str
The backend for joblib. Default is 'loky'.
The backend for joblib. Default is 'loky'
verbose: bool
If True output is verbose
@@ -689,12 +688,11 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret
for unit_id in unit_ids:
amplitudes = sorting.get_unit_spike_features(unit_id, 'amplitudes')
amp_list.append(amplitudes)
if return_idxs:
if len(amplitudes) < len(sorting.get_unit_spike_train(unit_id)):
indexes = sorting.get_unit_spike_features(unit_id, 'amplitudes_idxs')
else:
indexes = np.arange(len(amplitudes))
spike_index_list.append(indexes)
if len(amplitudes) < len(sorting.get_unit_spike_train(unit_id)):
indexes = sorting.get_unit_spike_features(unit_id, 'amplitudes_idxs')
else:
indexes = np.arange(len(amplitudes))
spike_index_list.append(indexes)
else:
# pre-construct memmap arrays
if memmap:
@@ -740,9 +738,9 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret
else:
amp_list[i] = amps

if save_property_or_features:
for i, unit_id in enumerate(unit_ids):
sorting.set_unit_spike_features(unit_id, 'amplitudes', amp_list[i], indexes=spike_index_list[i])
if save_property_or_features:
for i, unit_id in enumerate(unit_ids):
sorting.set_unit_spike_features(unit_id, 'amplitudes', amp_list[i], indexes=spike_index_list[i])

if return_idxs:
return amp_list, spike_index_list
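The get_unit_amplitudes hunks mirror the waveform fix: spike indexes are now gathered even when amplitudes are loaded from cached features, and the save_property_or_features block runs regardless of which branch produced the amplitudes. A rough sketch of the call, with the toy dataset assumed as before and the parallel arguments taken from the docstring above:

import spikeextractors as se
import spiketoolkit as st

recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

# Amplitudes per unit, computed with two loky workers; return_idxs=True also
# returns, per unit, the indices of the spikes the amplitudes refer to.
amps, amp_idxs = st.postprocessing.get_unit_amplitudes(
    recording, sorting, return_idxs=True, n_jobs=2, joblib_backend='loky')
for unit_id, a, idxs in zip(sorting.get_unit_ids(), amps, amp_idxs):
    print(unit_id, len(a), len(idxs))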
@@ -809,44 +807,12 @@ def compute_channel_spiking_activity(recording, channel_ids=None, detect_thresho
start_frame = 0
if end_frame is None:
end_frame = recording.get_num_frames()
duration = (end_frame - start_frame) / recording.get_sampling_frequency()

assert np.all([ch in recording.get_channel_ids() for ch in channel_ids]), "Invalid channel_ids"

spike_rates = np.zeros(len(channel_ids))
spike_amplitudes = np.zeros(len(channel_ids))

if n_jobs is None:
n_jobs = 1
if n_jobs == 0:
n_jobs = 1

if start_frame != 0 or end_frame != recording.get_num_frames():
recording_sub = se.SubRecordingExtractor(recording, start_frame=start_frame, end_frame=end_frame)
else:
recording_sub = recording

num_frames = recording_sub.get_num_frames()

# set chunk size
if chunk_size is not None:
chunk_size = int(chunk_size)
elif chunk_mb is not None:
n_bytes = np.dtype(recording.get_dtype()).itemsize
max_size = int(chunk_mb * 1e6) # set Mb per chunk
chunk_size = max_size // (recording.get_num_channels() * n_bytes)

if n_jobs > 1:
chunk_size /= n_jobs

# chunk_size = num_bytes_per_chunk / num_bytes_per_frame
chunks = divide_recording_into_time_chunks(
num_frames=num_frames,
chunk_size=chunk_size,
padding_size=0
)
n_chunk = len(chunks)

if 'spike_rate' in recording.get_shared_channel_property_names() and \
'spike_amplitude' in recording.get_shared_channel_property_names() and not recompute_info:
for i, ch in enumerate(recording.get_channel_ids()):
@@ -855,7 +821,8 @@
else:
sort_detect = st.sortingcomponents.detect_spikes(recording, channel_ids=channel_ids,
detect_threshold=detect_threshold, detect_sign=detect_sign,
n_jobs=n_jobs, start_frame=start_frame, end_frame=end_frame,
n_jobs=n_jobs, joblib_backend=joblib_backend,
start_frame=start_frame, end_frame=end_frame,
verbose=verbose)

for i, unit in enumerate(sort_detect.get_unit_ids()):
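In compute_channel_spiking_activity the removed block is the in-function duration/chunking setup, which is no longer needed because chunking happens inside st.sortingcomponents.detect_spikes; the fix also forwards joblib_backend to that call. A hedged sketch of the call, assuming the function returns the per-channel rate and amplitude arrays it fills:

import spikeextractors as se
import spiketoolkit as st

recording, _ = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

# Per-channel spiking activity; detection (and its chunking) is delegated to
# st.sortingcomponents.detect_spikes, including the joblib backend choice.
spike_rates, spike_amplitudes = st.postprocessing.compute_channel_spiking_activity(
    recording, detect_threshold=5, detect_sign=-1, n_jobs=2, joblib_backend='loky')
print(spike_rates.shape, spike_amplitudes.shape)  # (num_channels,) each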
@@ -1058,17 +1025,16 @@ def compute_unit_pca_scores(recording, sorting, unit_ids=None, channel_ids=None,
for unit_id in unit_ids:
pca_scores = sorting.get_unit_spike_features(unit_id, 'pca_scores')
pca_scores_list.append(pca_scores)
if return_idxs:
if len(pca_scores) < len(sorting.get_unit_spike_train(unit_id)):
indexes = sorting.get_unit_spike_features(unit_id, 'pca_scores_idxs')
else:
indexes = np.arange(len(pca_scores))
if 'pca_scores_channel_idxs' in sorting.get_shared_unit_property_names() and by_electrode:
channel_idxs = sorting.get_unit_property(unit_id, 'pca_scores_channel_idxs')
else:
channel_idxs = np.arange(recording.get_num_channels())
spike_index_list.append(indexes)
channel_index_list.append(channel_idxs)
if len(pca_scores) < len(sorting.get_unit_spike_train(unit_id)):
indexes = sorting.get_unit_spike_features(unit_id, 'pca_scores_idxs')
else:
indexes = np.arange(len(pca_scores))
if 'pca_scores_channel_idxs' in sorting.get_shared_unit_property_names() and by_electrode:
channel_idxs = sorting.get_unit_property(unit_id, 'pca_scores_channel_idxs')
else:
channel_idxs = np.arange(recording.get_num_channels())
spike_index_list.append(indexes)
channel_index_list.append(channel_idxs)
else:
if max_spikes_for_pca is None:
max_spikes_for_pca = np.inf
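compute_unit_pca_scores receives the same treatment as the waveform and amplitude functions: the cached-feature branch now always builds the spike and channel index lists. A usage sketch under the same toy-dataset assumption, with by_electrode and max_spikes_for_pca taken from the surrounding diff:

import spikeextractors as se
import spiketoolkit as st

recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

# Per-electrode PCA projections; with return_idxs=True the patched branch should
# also return spike and channel index lists, as in the waveform case.
pca_scores, spike_idxs, channel_idxs = st.postprocessing.compute_unit_pca_scores(
    recording, sorting, by_electrode=True, max_spikes_for_pca=300, return_idxs=True)
print(len(pca_scores), len(spike_idxs), len(channel_idxs))  # one entry per unit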
17 changes: 10 additions & 7 deletions spiketoolkit/sortingcomponents/detection.py
@@ -8,7 +8,7 @@

def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-1,
n_shifts=2, n_snippets_for_threshold=10, snippet_size_sec=1,
start_frame=None, end_frame=None, n_jobs=1,
start_frame=None, end_frame=None, n_jobs=1, joblib_backend='loky',
chunk_size=None, chunk_mb=500, verbose=False):
'''
Detects spikes per channel. Spikes are detected as threshold crossings and the threshold is in terms of the median
@@ -38,13 +38,15 @@ def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-
end_frame: int
End frame for detection
n_jobs: int
Number of jobs when parallel
Number of jobs for parallelization. Default is None (no parallelization)
joblib_backend: str
The backend for joblib. Default is 'loky'
chunk_size: int
Size of chunks in number of samples. If None, it is automatically calculated
chunk_mb: int
Size of chunks in Mb (default 500 Mb)
verbose: bool
If True output is verbose
If True output is verbose
Returns
-------
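With the new signature, joblib_backend sits next to n_jobs; a minimal call sketch against the assumed toy dataset:

import spikeextractors as se
import spiketoolkit as st

recording, _ = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

# Threshold crossings on every channel, parallelized over time chunks with two
# loky workers; the result behaves like a SortingExtractor with one unit per
# detecting channel (see the loop over sort_detect.get_unit_ids() above).
sort_detect = st.sortingcomponents.detect_spikes(
    recording, detect_threshold=5, detect_sign=-1,
    n_jobs=2, joblib_backend='loky', chunk_mb=100)
print(sort_detect.get_unit_ids())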
@@ -124,10 +126,11 @@ def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-
thresholds = detect_threshold * np.median(np.abs(traces_mad) / 0.6745, 1)[:, None]

if n_jobs > 1:
output = Parallel(n_jobs=n_jobs)(delayed(_detect_and_align_peaks_chunk)
(ii, rec_arg, chunks, channel_ids, thresholds, detect_sign,
n_shifts, verbose)
for ii in chunk_iter)
output = Parallel(n_jobs=n_jobs, backend=joblib_backend)(delayed(_detect_and_align_peaks_chunk)
(ii, rec_arg, chunks, channel_ids, thresholds,
detect_sign,
n_shifts, verbose)
for ii in chunk_iter)
for ii, (times_ii, amps_ii) in enumerate(output):
for i, ch in enumerate(channel_ids):
times = times_ii[i]
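The only behavioural change in this last hunk is that the chosen backend is forwarded to joblib's Parallel. A self-contained sketch of the same pattern, with a stand-in worker instead of _detect_and_align_peaks_chunk:

from joblib import Parallel, delayed

def process_chunk(chunk_index):
    # Stand-in for the per-chunk detection worker: just squares the index.
    return chunk_index ** 2

# Same call shape as the patched code: Parallel(n_jobs=..., backend=...) applied
# to a generator of delayed calls, one per chunk.
results = Parallel(n_jobs=2, backend='loky')(
    delayed(process_chunk)(ii) for ii in range(8))
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]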
