Skip to content

Commit

Permalink
Commit
feat: replace `output_folder` with `folder` when calling `run_sorter`; use the default value for `peak_sign` (drop the explicit `peak_sign="both"` argument) in `ChannelSparsity.from_best_channels`
Browse files Browse the repository at this point in the history
Authored and committed by ttngu207 on Jun 3, 2024
1 parent b459709 commit 1a1b18f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
21 changes: 14 additions & 7 deletions element_array_ephys/ephys_no_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,10 +1045,13 @@ def make(self, key):
# Find representative channel for each unit
unit_peak_channel: dict[int, np.ndarray] = (
si.ChannelSparsity.from_best_channels(
sorting_analyzer, 1, peak_sign="both"
sorting_analyzer,
1,
).unit_id_to_channel_indices
)
unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
unit_peak_channel: dict[int, int] = {
u: chn[0] for u, chn in unit_peak_channel.items()
}

spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
# {unit: spike_count}
Expand Down Expand Up @@ -1084,7 +1087,9 @@ def make(self, key):
)
unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
_, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates
spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
spike_times = si_sorting.get_unit_spike_train(
unit_id, return_times=True
)

assert len(spike_times) == len(spike_sites) == len(spike_depths)

Expand Down Expand Up @@ -1243,13 +1248,13 @@ def make(self, key):
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
import spikeinterface as si

sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)

# Find representative channel for each unit
unit_peak_channel: dict[int, np.ndarray] = (
si.ChannelSparsity.from_best_channels(
sorting_analyzer, 1, peak_sign="both"
sorting_analyzer, 1
).unit_id_to_channel_indices
) # {unit: peak_channel_index}
unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
Expand All @@ -1272,7 +1277,9 @@ def yield_unit_waveforms():
)
unit_peak_waveform = {
**unit,
"peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
"peak_electrode_waveform": unit_waveforms[
:, unit_peak_channel[unit["unit"]]
],
}

unit_electrode_waveforms = [
Expand Down Expand Up @@ -1495,7 +1502,7 @@ def make(self, key):
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
import spikeinterface as si

sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
template_metrics = sorting_analyzer.get_extension(
Expand Down
2 changes: 1 addition & 1 deletion element_array_ephys/spike_sorting/si_spike_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _run_sorter():
si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
sorter_name=sorter_name,
recording=si_recording,
output_folder=sorting_output_dir,
folder=sorting_output_dir,
remove_existing_folder=True,
verbose=True,
docker_image=sorter_name not in si.sorters.installed_sorters(),
Expand Down

0 comments on commit 1a1b18f

Please sign in to comment.