Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 7, 2024
1 parent caf15df commit 817302f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
11 changes: 6 additions & 5 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,9 @@ def _merge_extension_data(
new_data["waveforms"] = waveforms.copy()
for to_be_merged, unit_id in zip(units_to_merge, new_unit_ids):
new_channel_ids = sparsity.unit_id_to_channel_ids[unit_id]
new_waveforms, spike_indices = self.get_some_waveforms(new_channel_ids,
to_be_merged,
kept_waveforms=new_data["waveforms"],
kept_spikes=some_spikes)
new_waveforms, spike_indices = self.get_some_waveforms(
new_channel_ids, to_be_merged, kept_waveforms=new_data["waveforms"], kept_spikes=some_spikes
)
num_chans = new_waveforms.shape[2]
new_data["waveforms"][spike_indices, :, :num_chans] = new_waveforms
new_data["waveforms"][spike_indices, :, num_chans:] = 0
Expand Down Expand Up @@ -382,7 +381,9 @@ def get_some_waveforms(self, channel_ids=None, unit_ids=None, kept_waveforms=Non

for unit_id in unit_ids:
unit_index = sorting.id_to_index(unit_id)
sparse_waveforms = self.get_waveforms_one_unit(unit_id, kept_waveforms=kept_waveforms, kept_spikes=kept_spikes)
sparse_waveforms = self.get_waveforms_one_unit(
unit_id, kept_waveforms=kept_waveforms, kept_spikes=kept_spikes
)
local_chan_inds = sparsity.unit_id_to_channel_indices[unit_id]

# keep only requested channels
Expand Down
13 changes: 7 additions & 6 deletions src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ def _merge_extension_data(
new_data["pca_projection"] = pca_projections.copy()
for to_be_merge, unit_id in zip(units_to_merge, new_unit_ids):
new_channel_ids = sparsity.unit_id_to_channel_ids[unit_id]
new_projections, spike_indices = self.get_some_projections(new_channel_ids,
to_be_merge,
kept_projections=pca_projections,
kept_spikes=some_spikes)
new_projections, spike_indices = self.get_some_projections(
new_channel_ids, to_be_merge, kept_projections=pca_projections, kept_spikes=some_spikes
)
num_chans = new_projections.shape[2]
new_data["pca_projection"][spike_indices, :, :num_chans] = new_projections
new_data["pca_projection"][spike_indices, :, num_chans:] = 0
Expand Down Expand Up @@ -246,7 +245,7 @@ def get_some_projections(self, channel_ids=None, unit_ids=None, kept_projections
channel_indices = self.sorting_analyzer.channel_ids_to_indices(channel_ids)

# note : internally when sparse PCA are not aligned!! Exactly like waveforms.

sparsity = self.sorting_analyzer.sparsity

if kept_projections is not None:
Expand All @@ -273,7 +272,9 @@ def get_some_projections(self, channel_ids=None, unit_ids=None, kept_projections
some_projections = np.zeros((selected_inds.size, num_components, channel_indices.size), dtype=dtype)
for unit_id in unit_ids:
unit_index = sorting.id_to_index(unit_id)
sparse_projection = self.get_projections_one_unit(unit_id, sparse=True, kept_projections=kept_projections, kept_spikes=kept_spikes)
sparse_projection = self.get_projections_one_unit(
unit_id, sparse=True, kept_projections=kept_projections, kept_spikes=kept_spikes
)
local_chan_inds = sparsity.unit_id_to_channel_indices[unit_id]
# keep only requested channels
channel_mask_local_inds = np.isin(local_chan_inds, channel_indices)
Expand Down

0 comments on commit 817302f

Please sign in to comment.