diff --git a/src/spikeinterface/extractors/KilosortTempWhExtractor.py b/src/spikeinterface/extractors/KilosortTempWhExtractor.py index aac0903e36..480e5ba7f1 100644 --- a/src/spikeinterface/extractors/KilosortTempWhExtractor.py +++ b/src/spikeinterface/extractors/KilosortTempWhExtractor.py @@ -1,3 +1,4 @@ +from typing import Dict import spikeinterface as si from pathlib import Path from spikeinterface import BinaryRecordingExtractor @@ -6,50 +7,65 @@ import importlib.util import runpy import numpy as np +from spikeinterface import WaveformExtractor, extract_waveforms class KilosortTempWhExtractor(BinaryRecordingExtractor): def __init__(self, output_path: Path) -> None: - self.sorter_output_path = output_path / "sorter_output" # TODO: store opts e.g. ntb, Nbatch etc here. - params = runpy.run_path(self.sorter_output_path / "params.py") + if self.has_spikeinterface(output_path): + self.sorter_output_path = output_path / "sorter_output" - file_paths = Path(self.sorter_output_path) / "temp_wh.dat" # some assert - sampling_frequency = params["sample_rate"] - dtype = params["dtype"] - assert dtype == "int16" + channel_indices = self.get_channel_indices() - channel_map = np.load(self.sorter_output_path / "channel_map.npy") - if channel_map.ndim == 2: # kilosort > 2 - channel_indices = channel_map.ravel() # TODO: check multiple shanks - else: - assert channel_map.ndim == 1 - channel_indices = channel_map + original_recording = load_extractor(output_path / "spikeinterface_recording.json", base_folder=output_path) + channel_ids = original_recording.get_channel_ids() - num_channels = channel_indices.size + # TODO: check this assumption - does KS change the scale / offset? can check + # by performing no processing... 
+ if original_recording.has_scaled(): + gain_to_uV = original_recording.get_property("gain_to_uV")[channel_indices] + offset_to_uV = original_recording.get_property("offset_to_uV")[channel_indices] + else: + gain_to_uV = None + offset_to_uV = None - original_recording = load_extractor(output_path / "spikeinterface_recording.json", base_folder=output_path) - original_channel_ids = original_recording.get_channel_ids() + channel_locations = original_recording.get_channel_locations() + + # TODO: I think this is safe to assume as if the recording was + # sorted then it must have a probe attached. + probe = original_recording.get_probe() + + elif self.has_valid_sorter_output(output_path): + self.sorter_output_path = output_path + + channel_indices = self.get_channel_indices() + channel_ids = np.array(channel_indices, dtype=str) - if original_recording.has_scaled(): - gain_to_uV = original_recording.get_property("gain_to_uV")[ - channel_indices - ] # TODO: check this assumption - does KS change the scale / offset? can check by performing no processing... - offset_to_uV = original_recording.get_property("offset_to_uV")[channel_indices] - else: gain_to_uV = None offset_to_uV = None - self.original_recording_num_samples = original_recording.get_num_samples() - new_channel_ids = original_channel_ids[channel_indices] # TODO: check whether this will erroneously re-order - # is_filtered = original_recording.is_filtered or ## params was filtering run + channel_locations = np.load(self.sorter_output_path / "channel_positions.npy") + probe = None + + else: + raise ValueError("") + + params = self.load_and_check_kilosort_params_file() + temp_wh_path = Path(self.sorter_output_path) / "temp_wh.dat" + + new_channel_ids = channel_ids[channel_indices] + new_channel_locations = channel_locations[channel_indices] + # TODO: need to adjust probe? 
+ # TODO: check whether this will erroneously re-order + # is_filtered = original_recording.is_filtered or ## params was filtering run super(KilosortTempWhExtractor, self).__init__( - file_paths, - sampling_frequency, - dtype, - num_channels=num_channels, + temp_wh_path, + params["sample_rate"], + params["dtype"], + num_channels=channel_indices.size, t_starts=None, channel_ids=new_channel_ids, time_axis=0, @@ -59,23 +75,54 @@ def __init__(self, output_path: Path) -> None: is_filtered=None, num_chan=None, ) + self.set_channel_locations(new_channel_locations) + + # if probe: + # self.set_probe(probe) + + def get_channel_indices(self): + """""" + channel_map = np.load(self.sorter_output_path / "channel_map.npy") + + if channel_map.ndim == 2: + channel_indices = channel_map.ravel() + else: + assert channel_map.ndim == 1 + channel_indices = channel_map - # TODO: check, there must be a probe if sorting was run? - # change the wiring of the probe - # TODO: check this carefully, might be completely wrong + return channel_indices - contact_vector = original_recording.get_property("contact_vector") - contact_vector = contact_vector[channel_indices] - # if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(new_channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + def has_spikeinterface(self, path_: Path) -> bool: + """ """ + sorter_output = path_ / "sorter_output" - data2 = original_recording.get_traces(start_frame=0, end_frame=75000) - breakpoint() + if not (path_ / "spikeinterface_recording.json").is_file() or not sorter_output.is_dir(): + return False + + return self.has_valid_sorter_output(sorter_output) + + def has_valid_sorter_output(self, path_: Path) -> bool: + """ """ + required_files = ["temp_wh.dat", "channel_map.npy", "channel_positions.npy"] + + for filename in required_files: + if not (path_ / filename).is_file(): + print(f"The file {filename} cannot be found in {path_}") + return False + return 
True + + def load_and_check_kilosort_params_file(self) -> Dict: + """ """ + params = runpy.run_path(self.sorter_output_path / "params.py") + + if params["dtype"] != "int16": + raise ValueError("The dtype in kilosort's params.py is expected " "to be `int16`.") + + return params # original_probe = original_recording.get_probe() -# self.set_probe(original_probe) +# self.set_probe(original_probe) TODO: do we need to adjust the probe? what about contact positions? # 1) figure out metadata and casting for WaveForm Extractor # 2) check lazyness etc. @@ -83,13 +130,113 @@ def __init__(self, output_path: Path) -> None: # zero padding can just be kept. Check it plays nice with WaveformExtractor... # TODO: add provenance - +# TODO: what to do about all those zeros? # def get_num_samples(self): # """ ignore Kilosort's zero-padding """ # return self.original_recording.get_num_samples() -path_ = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output") -data = KilosortTempWhExtractor(path_) +# TODO: check, there must be a probe if sorting was run? +# change the wiring of the probe +# TODO: check this carefully, might be completely wrong +# if contact_vector is not None: + +# if channel_map.ndim == 2: # kilosort > 2 +# channel_indices = channel_map.ravel() # TODO: check multiple shanks + +# self.set_channel_locations(new_channel_locations) # TOOD: check against slice_channels + +# is_filtered=None, # TODO: need to get from KS provenence? + +# In general, do we store the full channel map in channel contacts or do we +# only save the new subset? My guess is subset for contact_positions, but full probe +# for probe. Check against slice_channels. +# self.set_probe(probe) # TODO: what does this mean for missing channels? + +# if channel_map.ndim == 2: # kilosort > 2 +# does kilosort > 2 store shanks differently? 
channel_indices = channel_map.ravel() + +path_ = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output") # sorter_output +recording_new = KilosortTempWhExtractor(path_) + +from spikeinterface import extractors +from spikeinterface import postprocessing + +sorting = extractors.read_kilosort( + folder_path=(path_ / "sorter_output").as_posix(), + keep_good_only=False, +) + +recording_old = load_extractor(path_ / "spikeinterface_recording.json", base_folder=path_) +folder_old = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\waveform_folder_old") +waveforms_old = extract_waveforms( + recording_old, + sorting, + folder_old, + ms_before=1.5, + ms_after=2, + max_spikes_per_unit=500, + allow_unfiltered=True, + load_if_exists=True, +) # match kilosort + +folder_new = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\waveform_folder_new") +waveforms_new = extract_waveforms( + recording_new, + sorting, + folder_new, + ms_before=1.5, + ms_after=2, + max_spikes_per_unit=500, + allow_unfiltered=True, + load_if_exists=True, +) # match kilosort breakpoint() + +if False: + import matplotlib.pyplot as plt + + plt.plot(kilosort_waveform) + plt.show() + + plt.plot(test_waveform) + plt.show() + + +# if folder.is_dir(): +# import shutil +# shutil.rmtree(folder) + +# run sorting without kilosort preprocessing +# then, the `temp_wh.dat` should match exactly the original file! +# I think this is a solid way to test. It is not possible to test against +# + + +if False: + original_recording = load_extractor(path_ / "spikeinterface_recording.json", base_folder=path_) + waveforms_old = extract_waveforms( + original_recording, + sorting, + folder, + ms_before=1.5, + ms_after=2.0, + max_spikes_per_unit=500, + allow_unfiltered=True, + load_if_exists=True, + ) + + original_recording = load_extractor(path_ / "spikeinterface_recording.json", base_folder=path_) + + # TODO: unit locations don't match kilosort very well, at least in the 1-spike case. 
+ # But, this could be due to windowing and should average out over many spikes + breakpoint() + + unit_locations_old = postprocessing.compute_unit_locations(waveforms, method="center_of_mass", outputs="by_unit") + unit_locations_pandas = pd.DataFrame.from_dict(unit_locations, orient="index", columns=["x", "y"]) + unit_locations_pandas.to_csv(unit_locations_path) + + utils.message_user(f"Unit locations saved to {unit_locations_path}") + + print(we)