Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce extractor for Kilosort's temp_wh.dat #1954

Closed
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 189 additions & 42 deletions src/spikeinterface/extractors/KilosortTempWhExtractor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict
import spikeinterface as si
from pathlib import Path
from spikeinterface import BinaryRecordingExtractor
Expand All @@ -6,50 +7,65 @@
import importlib.util
import runpy
import numpy as np
from spikeinterface import WaveformExtractor, extract_waveforms


class KilosortTempWhExtractor(BinaryRecordingExtractor):
def __init__(self, output_path: Path) -> None:
self.sorter_output_path = output_path / "sorter_output"
# TODO: store opts e.g. ntb, Nbatch etc here.

params = runpy.run_path(self.sorter_output_path / "params.py")
if self.has_spikeinterface(output_path):
self.sorter_output_path = output_path / "sorter_output"

file_paths = Path(self.sorter_output_path) / "temp_wh.dat" # some assert
sampling_frequency = params["sample_rate"]
dtype = params["dtype"]
assert dtype == "int16"
channel_indices = self.get_channel_indices()

channel_map = np.load(self.sorter_output_path / "channel_map.npy")
if channel_map.ndim == 2: # kilosort > 2
channel_indices = channel_map.ravel() # TODO: check multiple shanks
else:
assert channel_map.ndim == 1
channel_indices = channel_map
original_recording = load_extractor(output_path / "spikeinterface_recording.json", base_folder=output_path)
channel_ids = original_recording.get_channel_ids()

num_channels = channel_indices.size
# TODO: check this assumption - does KS change the scale / offset? can check
# by performing no processing...
if original_recording.has_scaled():
gain_to_uV = original_recording.get_property("gain_to_uV")[channel_indices]
offset_to_uV = original_recording.get_property("offset_to_uV")[channel_indices]
else:
gain_to_uV = None
offset_to_uV = None

original_recording = load_extractor(output_path / "spikeinterface_recording.json", base_folder=output_path)
original_channel_ids = original_recording.get_channel_ids()
channel_locations = original_recording.get_channel_locations()

# TODO: I think this is safe to assume as if the recording was
# sorted then it must have a probe attached.
probe = original_recording.get_probe()

elif self.has_valid_sorter_output(output_path):
self.sorter_output_path = output_path

channel_indices = self.get_channel_indices()
channel_ids = np.array(channel_indices, dtype=str)

if original_recording.has_scaled():
gain_to_uV = original_recording.get_property("gain_to_uV")[
channel_indices
] # TODO: check this assumption - does KS change the scale / offset? can check by performing no processing...
offset_to_uV = original_recording.get_property("offset_to_uV")[channel_indices]
else:
gain_to_uV = None
offset_to_uV = None

self.original_recording_num_samples = original_recording.get_num_samples()
new_channel_ids = original_channel_ids[channel_indices] # TODO: check whether this will erroneously re-order
# is_filtered = original_recording.is_filtered or ## params was filtering run
channel_locations = np.load(self.sorter_output_path / "channel_positions.npy")
probe = None

else:
raise ValueError("")

params = self.load_and_check_kilosort_params_file()
temp_wh_path = Path(self.sorter_output_path) / "temp_wh.dat"

new_channel_ids = channel_ids[channel_indices]
new_channel_locations = channel_locations[channel_indices]

# TODO: need to adjust probe?
# TODO: check whether this will erroneously re-order
# is_filtered = original_recording.is_filtered or ## params was filtering run
super(KilosortTempWhExtractor, self).__init__(
file_paths,
sampling_frequency,
dtype,
num_channels=num_channels,
temp_wh_path,
params["sample_rate"],
params["dtype"],
num_channels=channel_indices.size,
t_starts=None,
channel_ids=new_channel_ids,
time_axis=0,
Expand All @@ -59,37 +75,168 @@ def __init__(self, output_path: Path) -> None:
is_filtered=None,
num_chan=None,
)
self.set_channel_locations(new_channel_locations)

# if probe:
# self.set_probe(probe)

def get_channel_indices(self):
""""""
channel_map = np.load(self.sorter_output_path / "channel_map.npy")

if channel_map.ndim == 2:
channel_indices = channel_map.ravel()
else:
assert channel_map.ndim == 1
channel_indices = channel_map

# TODO: check, there must be a probe if sorting was run?
# change the wiring of the probe
# TODO: check this carefully, might be completely wrong
return channel_indices

contact_vector = original_recording.get_property("contact_vector")
contact_vector = contact_vector[channel_indices]
# if contact_vector is not None:
contact_vector["device_channel_indices"] = np.arange(len(new_channel_ids), dtype="int64")
self.set_property("contact_vector", contact_vector)
def has_spikeinterface(self, path_: Path) -> bool:
""" """
sorter_output = path_ / "sorter_output"

data2 = original_recording.get_traces(start_frame=0, end_frame=75000)
breakpoint()
if not (path_ / "spikeinterface_recording.json").is_file() or not sorter_output.is_dir():
return False

return self.has_valid_sorter_output(sorter_output)

def has_valid_sorter_output(self, path_: Path) -> bool:
""" """
required_files = ["temp_wh.dat", "channel_map.npy", "channel_positions.npy"]

for filename in required_files:
if not (path_ / filename).is_file():
print(f"The file {filename} cannot be out in {path_}")
return False
return True

def load_and_check_kilosort_params_file(self) -> Dict:
""" """
params = runpy.run_path(self.sorter_output_path / "params.py")

if params["dtype"] != "int16":
raise ValueError("The dtype in kilosort's params.py is expected" "to be `int16`.")

return params


# original_probe = original_recording.get_probe()
# self.set_probe(original_probe)
# self.set_probe(original_probe) TODO: do we need to adjust the probe? what about contact positions?

# 1) figure out metadata and casting for WaveForm Extractor
# 2) check lazyness etc.

# zero padding can just be kept. Check it plays nice with WaveformExtractor...

# TODO: add provenance

# TODO: what to do about all those zeros?

# def get_num_samples(self):
# """ ignore Kilosort's zero-padding """
# return self.original_recording.get_num_samples()

path_ = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output")
data = KilosortTempWhExtractor(path_)
# TODO: check, there must be a probe if sorting was run?
# change the wiring of the probe
# TODO: check this carefully, might be completely wrong
# if contact_vector is not None:

# if channel_map.ndim == 2: # kilosort > 2
# channel_indices = channel_map.ravel() # TODO: check multiple shanks

# self.set_channel_locations(new_channel_locations) # TOOD: check against slice_channels

# is_filtered=None, # TODO: need to get from KS provenence?

# In general, do we store the full channel map in channel contacts or do we
# only save the new subset? My guess is subset for contact_positions, but full probe
# for probe. Check against slice_channels.
# self.set_probe(probe) # TODO: what does this mean for missing channels?

# if channel_map.ndim == 2: # kilosort > 2
# does kilosort > 2 store shanks differently? channel_indices = channel_map.ravel()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoeZiminski ,

Kilosort 2 no longer uses the shank definition (MouseLand/Kilosort#155), so this is definitely true for Kilosort2. Based on my search of the code base of Kilosort 3 it seems like kcoords (shanks) are still not being used: (https://github.com/search?q=repo%3AMouseLand%2FKilosort%20kcoords&type=code) (ie the code is for loading the kcoords and putting them in rez, but then they aren't actually used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually not used in KS2.5 either.
MouseLand/Kilosort#262

path_ = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output") # sorter_output
recording_new = KilosortTempWhExtractor(path_)

from spikeinterface import extractors
from spikeinterface import postprocessing

sorting = extractors.read_kilosort(
folder_path=(path_ / "sorter_output").as_posix(),
keep_good_only=False,
)

recording_old = load_extractor(path_ / "spikeinterface_recording.json", base_folder=path_)
folder_old = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\waveform_folder_old")
waveforms_old = extract_waveforms(
recording_old,
sorting,
folder_old,
ms_before=1.5,
ms_after=2,
max_spikes_per_unit=500,
allow_unfiltered=True,
load_if_exists=True,
) # match kilosort

folder_new = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\waveform_folder_new")
waveforms_new = extract_waveforms(
recording_new,
sorting,
folder_new,
ms_before=1.5,
ms_after=2,
max_spikes_per_unit=500,
allow_unfiltered=True,
load_if_exists=True,
) # match kilosort

breakpoint()

if False:
import matplotlib.pyplot as plt

plt.plot(kilosort_waveform)
plt.show()

plt.plot(test_waveform)
plt.show()


# if folder.is_dir():
# import shutil
# shutil.rmtree(folder)

# run sorting without kilosort preprocessing
# then, the `temp_wh.dat` should match exactly the original file!
# I think this is a solid way to test. It is not possible to test against
#


if False:
original_recording = load_extractor(path_ / "spikeinterface_recording.json", base_folder=path_)
waveforms_old = extract_waveforms(
original_recording,
sorting,
folder,
ms_before=1.5,
ms_after=2.0,
max_spikes_per_unit=500,
allow_unfiltered=True,
load_if_exists=True,
)

original_recording = load_extractor(path_ / "spikeinterface_recording.json", base_folder=path_)

# TODO: unit locations don't match kilosort very well, at least in the 1-spike case.
# But, this could be due to windowing and should average out over many spikes
breakpoint()

unit_locations_old = postprocessing.compute_unit_locations(waveforms, method="center_of_mass", outputs="by_unit")
unit_locations_pandas = pd.DataFrame.from_dict(unit_locations, orient="index", columns=["x", "y"])
unit_locations_pandas.to_csv(unit_locations_path)

utils.message_user(f"Unit locations saved to {unit_locations_path}")

print(we)