Introduce extractor for Kilosort's temp_wh.dat #1954

Closed

Changes from 1 commit

95 changes: 95 additions & 0 deletions src/spikeinterface/extractors/KilosortTempWhExtractor.py
@@ -0,0 +1,95 @@
import spikeinterface as si
from pathlib import Path
from spikeinterface import BinaryRecordingExtractor
from spikeinterface import load_extractor
import sys
import importlib.util
import runpy
import numpy as np


class KilosortTempWhExtractor(BinaryRecordingExtractor):
def __init__(self, output_path: Path) -> None:
self.sorter_output_path = output_path / "sorter_output"
# TODO: store opts e.g. ntb, Nbatch etc here.

params = runpy.run_path(self.sorter_output_path / "params.py")

Member:
nice! I didn't know about it!
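
For reference, `runpy.run_path` executes a script and returns its module-level globals as a plain dict, so the Kilosort/phy-style params.py can be read without writing a parser. A minimal, self-contained illustration (the file contents below are example values, not taken from this PR):

```python
import runpy
import tempfile
from pathlib import Path

# Write a small params.py in the style Kilosort/phy emits (example values only).
params_file = Path(tempfile.mkdtemp()) / "params.py"
params_file.write_text(
    "dat_path = 'temp_wh.dat'\n"
    "n_channels_dat = 384\n"
    "dtype = 'int16'\n"
    "sample_rate = 30000.0\n"
)

# run_path executes the script and hands back its globals as a dict.
params = runpy.run_path(str(params_file))
print(params["sample_rate"], params["dtype"])  # 30000.0 int16
```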


file_paths = Path(self.sorter_output_path) / "temp_wh.dat" # some assert
sampling_frequency = params["sample_rate"]
dtype = params["dtype"]
assert dtype == "int16"

channel_map = np.load(self.sorter_output_path / "channel_map.npy")
if channel_map.ndim == 2: # kilosort > 2
channel_indices = channel_map.ravel() # TODO: check multiple shanks
else:
assert channel_map.ndim == 1
channel_indices = channel_map

num_channels = channel_indices.size

original_recording = load_extractor(output_path / "spikeinterface_recording.json", base_folder=output_path)
original_channel_ids = original_recording.get_channel_ids()

if original_recording.has_scaled():
gain_to_uV = original_recording.get_property("gain_to_uV")[
channel_indices
] # TODO: check this assumption - does KS change the scale / offset? can check by performing no processing...
offset_to_uV = original_recording.get_property("offset_to_uV")[channel_indices]
else:
gain_to_uV = None
offset_to_uV = None

self.original_recording_num_samples = original_recording.get_num_samples()
new_channel_ids = original_channel_ids[channel_indices] # TODO: check whether this will erroneously re-order

Member:
I think this should be optional. In case someone runs KS outside of SI, this class should still be able to load it and infer the channels etc. from somewhere else, no?

Member:
Like the chanMap.mat ;) (if available)!

Collaborator Author:
Hey @alejoe91, thanks for this! Sorry, I was at a conference last week.

That's a good point: it should not depend on having previously run the sorting with spikeinterface. The passed filepath could be checked; if it contains spikeinterface_recording.json + sorter_output, I think it is safe to assume it is an SI-run sorting. Otherwise, we can check whether it has some key files (temp_wh.dat, channel_map.npy) and assume it is a bare sorter output (sketched after this thread).

In that case, I guess the channel information and positions can be read from channel_map.npy and channel_positions.npy, and the channel IDs could just be the channel index as a string. I don't think the uV scaling can be recovered, though.

If it is instead loaded via an SI recording, we can recover the uV scaling and the channels can keep their proper channel IDs (as currently done). Either way, the two loading modes (with SI metadata vs. from the raw Kilosort output) should behave equivalently.

Does that make sense / sound reasonable?

Member:
Perfectly reasonable! That's exactly what I was suggesting ;)
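
A rough sketch of the folder-detection idea discussed in this thread; the helper name `_detect_output_layout` and its return values are illustrative only and not part of this PR:

```python
from pathlib import Path


def _detect_output_layout(output_path: Path) -> str:
    """Guess whether `output_path` is an SI-run sorting folder or a bare
    Kilosort output folder (hypothetical helper, for illustration only)."""
    output_path = Path(output_path)

    # SI-run sorting: spikeinterface_recording.json sits next to sorter_output/.
    if (output_path / "spikeinterface_recording.json").is_file() and (output_path / "sorter_output").is_dir():
        return "spikeinterface"

    # Bare Kilosort output: the key files live directly in the folder.
    if (output_path / "temp_wh.dat").is_file() and (output_path / "channel_map.npy").is_file():
        return "kilosort"

    raise ValueError(f"Could not recognise a Kilosort output layout in {output_path}")
```

In the second case the channel IDs would have to be derived from channel_map.npy / channel_positions.npy, and the uV scaling left unset, as discussed above.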

# is_filtered = original_recording.is_filtered or ## params was filtering run

super(KilosortTempWhExtractor, self).__init__(
file_paths,
sampling_frequency,
dtype,
num_channels=num_channels,
t_starts=None,
channel_ids=new_channel_ids,
time_axis=0,
file_offset=0,
gain_to_uV=gain_to_uV,
offset_to_uV=offset_to_uV,
is_filtered=None,
num_chan=None,
)

# TODO: check, there must be a probe if sorting was run?
# change the wiring of the probe
# TODO: check this carefully, might be completely wrong

contact_vector = original_recording.get_property("contact_vector")
contact_vector = contact_vector[channel_indices]
# if contact_vector is not None:
contact_vector["device_channel_indices"] = np.arange(len(new_channel_ids), dtype="int64")
self.set_property("contact_vector", contact_vector)

data2 = original_recording.get_traces(start_frame=0, end_frame=75000)
breakpoint()


# original_probe = original_recording.get_probe()
# self.set_probe(original_probe)

# 1) figure out metadata and casting for WaveForm Extractor
# 2) check lazyness etc.

# zero padding can just be kept. Check it plays nice with WaveformExtractor...

# TODO: add provenance


# def get_num_samples(self):
# """ ignore Kilosort's zero-padding """
# return self.original_recording.get_num_samples()
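
If the zero-padding does need to be hidden from downstream consumers, the commented-out stub above could become something like the sketch below. Note the stub refers to `self.original_recording`, which `__init__` never stores; only `self.original_recording_num_samples` is kept, so that is used here instead. Whether trimming is actually desirable, and whether the underlying binary segment would also need adjusting, is still an open question in this draft:

```python
# Method sketch (it would live on KilosortTempWhExtractor); shown standalone here.
def get_num_samples(self, segment_index=None):
    """Report the original (un-padded) length, ignoring Kilosort's
    zero-padding at the end of temp_wh.dat."""
    return self.original_recording_num_samples
```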

path_ = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output")
data = KilosortTempWhExtractor(path_)

breakpoint()