diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 084c3b37ea..df20811fad 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -119,47 +119,21 @@ def _get_locations_from_pc_features(params): Notes ----- - Location of of each individual spike is computed from its low-dimensional projection. - During sorting, kilosort computes the ' - `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. - Taking the first component, the subset of 32 channels associated with this - spike are indexed to get the actual channel locations (in um). Then, the channel - locations are weighted by their PC values. + My understanding so far. KS1 paper; The individual spike waveforms are decomposed into + 'private PCs'. Let the waveform matrix W be time (t) x channel (c). PCA + decomposition is performed to compute c basis waveforms. Scores for each + channel onto the top three PCs are stored (these recover the waveform well). This function is based on code in Nick Steinmetz's `spikes` repository, https://github.com/cortex-lab/spikes """ # Compute spike depths - - # for each spike, a PCA is computed just on that spike (n samples x n channels). - # the components are all different between spikes, so are not saved. - # This gives a (n pc = 3, num channels) set of scores. - # but then how it is possible for some spikes to have zero score onto the principal channel? - - breakpoint() - pc_features = params["pc_features"][:, 0, :] + pc_features = params["pc_features"][:, 0, :].copy() pc_features[pc_features < 0] = 0 - # Some spikes do not load at all onto the first PC. To avoid biasing the - # dataset by removing these, we repeat the above for the next PC, - # to compute distances for neurons that do not load onto the 1st PC. 
- # This is not ideal at all, it would be much better to a) find the - # max value for each channel on each of the PCs (i.e. basis vectors). - # Then recompute the estimated waveform peak on each channel by - # summing the PCs by their respective weights. However, the PC basis - # vectors themselves do not appear to be output by KS. - - # We include the (n_channels i.e. features) from the second PC - # into the `pc_features` mostly containing the first PC. As all - # operations are per-spike (i.e. row-wise) - no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0) - - pc_features_2 = params["pc_features"][:, 1, :] - pc_features_2[pc_features_2 < 0] = 0 - - pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes] - if np.any(np.sum(pc_features, axis=1) == 0): + # TODO: 1) handle this case for pc_features + # 2) instead use the template_features for all other versions. raise RuntimeError( "Some spikes do not load at all onto the first" "or second principal component. It is necessary" @@ -343,8 +317,15 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool if load_pcs: pc_features = np.load(sorter_output / "pc_features.npy") pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy") + + if (sorter_output / "template_features.npy").is_file(): + template_features = np.load(sorter_output / "template_features.npy") + template_features_indices = np.load(sorter_output / "templates_ind.npy") + else: + template_features = template_features_indices = None else: pc_features = pc_features_indices = None + template_features = template_features_indices = None # This makes the assumption that there will never be different .csv and .tsv files # in the same sorter output (this should never happen, there will never even be two). 
@@ -364,6 +345,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool if load_pcs: pc_features = pc_features[not_noise_clusters_by_spike, :, :] + if template_features is not None: + template_features = template_features[not_noise_clusters_by_spike, :, :] spike_clusters = spike_clusters[not_noise_clusters_by_spike] cluster_ids = cluster_ids[cluster_groups != 0] @@ -378,6 +361,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool "spike_clusters": spike_clusters.squeeze(), "pc_features": pc_features, "pc_features_indices": pc_features_indices, + "template_features": template_features, + "template_features_indices": template_features_indices, "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(), "cluster_ids": cluster_ids, "cluster_groups": cluster_groups,