diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 8eadba49..891cee0f 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1045,10 +1045,13 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="both"
+                    sorting_analyzer,
+                    1,
                 ).unit_id_to_channel_indices
             )
-            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
+            unit_peak_channel: dict[int, int] = {
+                u: chn[0] for u, chn in unit_peak_channel.items()
+            }
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
@@ -1084,7 +1087,9 @@ def make(self, key):
                 )
                 unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                 _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
-                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+                spike_times = si_sorting.get_unit_spike_train(
+                    unit_id, return_times=True
+                )
 
                 assert len(spike_times) == len(spike_sites) == len(spike_depths)
 
@@ -1243,13 +1248,13 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             import spikeinterface as si
-
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
 
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="both"
+                    sorting_analyzer, 1
                 ).unit_id_to_channel_indices
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
@@ -1272,7 +1277,9 @@ def yield_unit_waveforms():
                 )
                 unit_peak_waveform = {
                     **unit,
-                    "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
+                    "peak_electrode_waveform": unit_waveforms[
+                        :, unit_peak_channel[unit["unit"]]
+                    ],
                 }
 
                 unit_electrode_waveforms = [
@@ -1495,7 +1502,7 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             import spikeinterface as si
-
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
             template_metrics = sorting_analyzer.get_extension(
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 93619303..57aa0ba1 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -205,7 +205,7 @@ def _run_sorter():
         si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
             sorter_name=sorter_name,
             recording=si_recording,
-            output_folder=sorting_output_dir,
+            folder=sorting_output_dir,
             remove_existing_folder=True,
             verbose=True,
             docker_image=sorter_name not in si.sorters.installed_sorters(),