From 146179c0f48308d5d78b228f96d218ae1828e5b2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 20:58:26 +0000 Subject: [PATCH] Add some notes. --- .../working/load_kilosort_utils.py | 55 ++++--------------- 1 file changed, 10 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index aa3fb2babd..3f50700d66 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -49,7 +49,7 @@ def compute_spike_amplitude_and_depth( Notes ----- - In `_template_positions_amplitudes` spike depths is calculated as simply the template + In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template depth, for each spike (so it is the same for all spikes in a cluster). Here we need to find the depth of each individual spike, using its low-dimensional projection. `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. @@ -101,7 +101,7 @@ def compute_spike_amplitude_and_depth( # multiplied by the `template_scaling_amplitudes`. # Compute amplitudes, scale if required and drop un-localised spikes before returning. - spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes( + spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes( params["templates"], params["whitening_matrix_inv"], ycoords, @@ -112,9 +112,16 @@ def compute_spike_amplitude_and_depth( if gain is not None: spike_amplitudes *= gain + max_site = np.argmax( + np.max(np.abs(templates), axis=1), axis=1 + ) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here. max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) spike_sites = max_site[params["spike_templates"]] + # TODO: here the max site is the same for all spikes from the same template. + # is this the case for spikeinterface? Should we estimate max-site per spike from + # the PCs? + if localised_spikes_only: # Interpolate the channel ids to location. # Remove spikes > 5 um from average position @@ -134,45 +141,7 @@ def compute_spike_amplitude_and_depth( return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything -def _filter_large_amplitude_spikes( - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - large_amplitude_only_segment_size, -) -> tuple[np.ndarray, ...]: - """ - Return spike properties with only the largest-amplitude spikes included. The probe - is split into egments, and within each segment the mean and std computed. - Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded - Splitting the probe is only done for the exclusion step, the returned array are flat. - - Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns - copies of these arrays containing only the large amplitude spikes. - """ - spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) - - segment_size_um = large_amplitude_only_segment_size - probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um - - for segment_left_edge in probe_segments_left_edges: - segment_right_edge = segment_left_edge + segment_size_um - - spikes_in_seg = np.where(np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge))[ - 0 - ] - spike_amps_in_seg = spike_amplitudes[spikes_in_seg] - is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) - - spike_bool[spikes_in_seg] = is_high_amplitude - - spike_times = spike_times[spike_bool] - spike_amplitudes = spike_amplitudes[spike_bool] - spike_depths = spike_depths[spike_bool] - - return spike_times, spike_amplitudes, spike_depths - - -def _template_positions_amplitudes( +def get_template_info_and_spike_amplitudes( templates: np.ndarray, inverse_whitening_matrix: np.ndarray, ycoords: np.ndarray, @@ -256,9 +225,6 @@ def _template_positions_amplitudes( counts = np.bincount(spike_templates, minlength=num_indices) template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) - # Each spike's depth is the depth of its template - spike_depths = template_depths[spike_templates] - # Get channel with the largest amplitude (take that as the waveform) max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1) @@ -279,7 +245,6 @@ def _template_positions_amplitudes( return ( spike_amplitudes, - spike_depths, template_depths, template_amplitudes, unwhite_templates,