Skip to content

Commit

Permalink
Add some notes.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 13, 2024
1 parent 84b9a17 commit 146179c
Showing 1 changed file with 10 additions and 45 deletions.
55 changes: 10 additions & 45 deletions src/spikeinterface/working/load_kilosort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def compute_spike_amplitude_and_depth(
Notes
-----
In `_template_positions_amplitudes` spike depths is calculated as simply the template
In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template
depth, for each spike (so it is the same for all spikes in a cluster). Here we need
to find the depth of each individual spike, using its low-dimensional projection.
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
Expand Down Expand Up @@ -101,7 +101,7 @@ def compute_spike_amplitude_and_depth(
# multiplied by the `template_scaling_amplitudes`.

# Compute amplitudes, scale if required and drop un-localised spikes before returning.
spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes(
spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes(
params["templates"],
params["whitening_matrix_inv"],
ycoords,
Expand All @@ -112,9 +112,16 @@ def compute_spike_amplitude_and_depth(
if gain is not None:
spike_amplitudes *= gain

max_site = np.argmax(
np.max(np.abs(templates), axis=1), axis=1
) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1)
spike_sites = max_site[params["spike_templates"]]

# TODO: here the max site is the same for all spikes from the same template.
# is this the case for spikeinterface? Should we estimate max-site per spike from
# the PCs?

if localised_spikes_only:
# Interpolate the channel ids to location.
# Remove spikes > 5 um from average position
Expand All @@ -134,45 +141,7 @@ def compute_spike_amplitude_and_depth(
return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything


def _filter_large_amplitude_spikes(
spike_times: np.ndarray,
spike_amplitudes: np.ndarray,
spike_depths: np.ndarray,
large_amplitude_only_segment_size,
) -> tuple[np.ndarray, ...]:
"""
Return spike properties with only the largest-amplitude spikes included. The probe
is split into egments, and within each segment the mean and std computed.
Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded
Splitting the probe is only done for the exclusion step, the returned array are flat.
Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns
copies of these arrays containing only the large amplitude spikes.
"""
spike_bool = np.zeros_like(spike_amplitudes, dtype=bool)

segment_size_um = large_amplitude_only_segment_size
probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um

for segment_left_edge in probe_segments_left_edges:
segment_right_edge = segment_left_edge + segment_size_um

spikes_in_seg = np.where(np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge))[
0
]
spike_amps_in_seg = spike_amplitudes[spikes_in_seg]
is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1)

spike_bool[spikes_in_seg] = is_high_amplitude

spike_times = spike_times[spike_bool]
spike_amplitudes = spike_amplitudes[spike_bool]
spike_depths = spike_depths[spike_bool]

return spike_times, spike_amplitudes, spike_depths


def _template_positions_amplitudes(
def get_template_info_and_spike_amplitudes(
templates: np.ndarray,
inverse_whitening_matrix: np.ndarray,
ycoords: np.ndarray,
Expand Down Expand Up @@ -256,9 +225,6 @@ def _template_positions_amplitudes(
counts = np.bincount(spike_templates, minlength=num_indices)
template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)

# Each spike's depth is the depth of its template
spike_depths = template_depths[spike_templates]

# Get channel with the largest amplitude (take that as the waveform)
max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1)

Expand All @@ -279,7 +245,6 @@ def _template_positions_amplitudes(

return (
spike_amplitudes,
spike_depths,
template_depths,
template_amplitudes,
unwhite_templates,
Expand Down

0 comments on commit 146179c

Please sign in to comment.