Skip to content

Commit

Permalink
Look into the template option.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 14, 2024
1 parent b5c85ff commit 250eb54
Showing 1 changed file with 18 additions and 33 deletions.
51 changes: 18 additions & 33 deletions src/spikeinterface/working/load_kilosort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,47 +119,21 @@ def _get_locations_from_pc_features(params):
Notes
-----
Location of of each individual spike is computed from its low-dimensional projection.
During sorting, kilosort computes the '
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
Taking the first component, the subset of 32 channels associated with this
spike are indexed to get the actual channel locations (in um). Then, the channel
locations are weighted by their PC values.
My understanding so far. KS1 paper; The individual spike waveforms are decomposed into
'private PCs'. Let the waveform matrix W be time (t) x channel (c). PCA
decompoisition is performed to compute c basis waveforms. Scores for each
channel onto the top three PCs are stored (these recover the waveform well.
This function is based on code in Nick Steinmetz's `spikes` repository,
https://github.com/cortex-lab/spikes
"""
# Compute spike depths

# for each spike, a PCA is computed just on that spike (n samples x n channels).
# the components are all different between spikes, so are not saved.
# This gives a (n pc = 3, num channels) set of scores.
# but then how it is possible for some spikes to have zero score onto the principal channel?

breakpoint()
pc_features = params["pc_features"][:, 0, :]
pc_features = params["pc_features"][:, 0, :].copy()
pc_features[pc_features < 0] = 0

# Some spikes do not load at all onto the first PC. To avoid biasing the
# dataset by removing these, we repeat the above for the next PC,
# to compute distances for neurons that do not load onto the 1st PC.
# This is not ideal at all, it would be much better to a) find the
# max value for each channel on each of the PCs (i.e. basis vectors).
# Then recompute the estimated waveform peak on each channel by
# summing the PCs by their respective weights. However, the PC basis
# vectors themselves do not appear to be output by KS.

# We include the (n_channels i.e. features) from the second PC
# into the `pc_features` mostly containing the first PC. As all
# operations are per-spike (i.e. row-wise)
no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0)

pc_features_2 = params["pc_features"][:, 1, :]
pc_features_2[pc_features_2 < 0] = 0

pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes]

if np.any(np.sum(pc_features, axis=1) == 0):
# TODO: 1) handle this case for pc_features
# 2) instead use the template_features for all other versions.
raise RuntimeError(
"Some spikes do not load at all onto the first"
"or second principal component. It is necessary"
Expand Down Expand Up @@ -343,8 +317,15 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
if load_pcs:
pc_features = np.load(sorter_output / "pc_features.npy")
pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy")

if (sorter_output / "template_features.npy").is_file():
template_features = np.load(sorter_output / "template_features.npy")
template_features_indices = np.load(sorter_output / "templates_ind.npy")
else:
template_features = template_features_indices = None
else:
pc_features = pc_features_indices = None
template_features = template_features_indices = None

# This makes the assumption that there will never be different .csv and .tsv files
# in the same sorter output (this should never happen, there will never even be two).
Expand All @@ -364,6 +345,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool

if load_pcs:
pc_features = pc_features[not_noise_clusters_by_spike, :, :]
if template_features is not None:
template_features = template_features[not_noise_clusters_by_spike, :, :]

spike_clusters = spike_clusters[not_noise_clusters_by_spike]
cluster_ids = cluster_ids[cluster_groups != 0]
Expand All @@ -378,6 +361,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
"spike_clusters": spike_clusters.squeeze(),
"pc_features": pc_features,
"pc_features_indices": pc_features_indices,
"template_features": template_features,
"template_features_indices": template_features_indices,
"temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(),
"cluster_ids": cluster_ids,
"cluster_groups": cluster_groups,
Expand Down

0 comments on commit 250eb54

Please sign in to comment.