diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e45e427..cd8bb5b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,12 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.4.0] - 2024-05-28 + ++ Add - support for SpikeInterface version >= 0.101.0 (updated API) ++ Add - feature for memoization of spike sorting results (prevent duplicated runs) + + ## [0.3.4] - 2024-03-22 + Add - pytest diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 2db87e39..2e37f721 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -7,13 +7,12 @@ import datajoint as dj import numpy as np import pandas as pd - from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory from . import ephys_report, probe from .readers import kilosort, openephys, spikeglx -log = dj.logger +logger = dj.logger schema = dj.schema() @@ -822,7 +821,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False): if mkdir: output_dir.mkdir(parents=True, exist_ok=True) - log.info(f"{output_dir} created!") + logger.info(f"{output_dir} created!") return output_dir.relative_to(processed_dir) if relative else output_dir @@ -1028,108 +1027,81 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") - channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { - chn.pop("channel_idx"): chn for chn in channel2electrode_map + chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) } # Get sorter method and create output directory. sorter_name = clustering_method.replace(".", "_") - si_waveform_dir = output_dir / sorter_name / "waveform" - si_sorting_dir = output_dir / sorter_name / "spike_sorting" + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" - if si_waveform_dir.exists(): # Read from spikeinterface outputs + if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs import spikeinterface as si - from spikeinterface import sorters - we: si.WaveformExtractor = si.load_waveforms( - si_waveform_dir, with_recording=False - ) - si_sorting: si.sorters.BaseSorter = si.load_extractor( - si_sorting_dir / "si_sorting.pkl", base_folder=output_dir - ) + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) + si_sorting = sorting_analyzer.sorting - unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( - we, outputs="index" - ) # {unit: peak_channel_id} + # Find representative channel for each unit + unit_peak_channel: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + sorting_analyzer, + 1, + ).unit_id_to_channel_indices + ) + unit_peak_channel: dict[int, int] = { + u: chn[0] for u, chn in unit_peak_channel.items() + } spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - spikes = si_sorting.to_spike_vector() - - # reorder channel2electrode_map according to recording channel ids + # update channel2electrode_map to match with probe's channel index channel2electrode_map = { - chn_idx: channel2electrode_map[chn_idx] - for chn_idx in we.channel_ids_to_indices(we.channel_ids) + idx: channel2electrode_map[int(chn_idx)] + for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) } # Get unit id to quality label mapping - try: - cluster_quality_label_map = pd.read_csv( - si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", - delimiter="\t", + cluster_quality_label_map = { + int(unit_id): ( + si_sorting.get_unit_property(unit_id, "KSLabel") + if "KSLabel" in si_sorting.get_property_keys() + else "n.a." ) - except FileNotFoundError: - cluster_quality_label_map = {} - else: - cluster_quality_label_map: dict[ - int, str - ] = cluster_quality_label_map.set_index("cluster_id")[ - "KSLabel" - ].to_dict() # {unit: quality_label} - - # Get electrode where peak unit activity is recorded - peak_electrode_ind = np.array( - [ - channel2electrode_map[unit_peak_channel[unit_id]]["electrode"] - for unit_id in si_sorting.unit_ids - ] - ) - - # Get channel depth - channel_depth_ind = np.array( - [ - we.get_probe().contact_positions[unit_peak_channel[unit_id]][1] - for unit_id in si_sorting.unit_ids - ] - ) - - # Assign electrode and depth for each spike - new_spikes = np.empty( - spikes.shape, - spikes.dtype.descr + [("electrode", "=4.2.0", "pyopenephys>=1.1.6", - "element-interface @ git+https://github.com/datajoint/element-interface.git", + "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results", "numba", ], extras_require={