WIP : get additional data from HerdingspikesSortingExtractor #3525

Open · wants to merge 6 commits into main · Changes from 3 commits

35 changes: 25 additions & 10 deletions src/spikeinterface/extractors/herdingspikesextractors.py

```diff
@@ -43,28 +43,39 @@ def __init__(self, file_path, load_unit_info=True):
         spike_ids = self._rf["cluster_id"][()]
         unit_ids = np.unique(spike_ids)
         spike_times = self._rf["times"][()]
+        unit_locs = self._rf["centres"][()]

-        if load_unit_info:
-            self.load_unit_info()
+        # if load_unit_info:
+        #     self.load_unit_info()

         BaseSorting.__init__(self, sampling_frequency, unit_ids)
-        self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids))
+        self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids, unit_locs))
         self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info}

         self.extra_requirements.append("h5py")

-    def load_unit_info(self):
-        # TODO
+    def get_unit_location(
+        self,
+        unit_id,
+        segment_index=None,
```

Member:
Why do we need segment_index here? Do the units change over segments?

Contributor Author:
I don't work with multi-segment recordings, so I don't have much perspective on this. Since get_unit_spike_train is implemented at both the sorting extractor and the sorting segment level, I did the same here just in case.

(A minimal illustration of this extractor-level/segment-level split is sketched after the diff hunk below.)

```diff
+    ):
+
+        segment_index = self._check_segment_index(segment_index)
+        segment = self._sorting_segments[segment_index]
+        loc = segment.get_unit_location(unit_id=unit_id)
+        return loc
+
+    """
+    def load_unit_info(self):

         if 'centres' in self._rf.keys() and len(self._spike_times) > 0:
             self._unit_locs = self._rf['centres'][()]  # cache for faster access
             for u_i, unit_id in enumerate(self._unit_ids):
                 self.set_unit_property(unit_id, property_name='unit_location', value=self._unit_locs[u_i])
         inds = []  # get these only once
         for unit_id in self._unit_ids:
             inds.append(np.where(self._cluster_id == unit_id)[0])
-        if 'data' in self._rf.keys() and len(self._spike_times) > 0:
-            d = self._rf['data'][()]
+        if 'x' in self._rf.keys() and 'y' in self._rf.keys() and len(self._spike_times) > 0:
+            x = self._rf['x'][()]
+            y = self._rf['y'][()]
             for i, unit_id in enumerate(self._unit_ids):
                 self.set_unit_spike_features(unit_id, 'spike_location', d[:, inds[i]].T)
         if 'ch' in self._rf.keys() and len(self._spike_times) > 0:
```

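For context on the review exchange above: get_unit_spike_train is exposed on the sorting extractor and implemented per segment, and this PR mirrors that split for get_unit_location. Below is a minimal, self-contained sketch of that pattern; the class names and the dict-based storage are made up for illustration and are not SpikeInterface's actual BaseSorting/BaseSortingSegment API.

```python
# Illustrative toy classes only; names and the dict-based storage are made up
# for this sketch and are not SpikeInterface's BaseSorting/BaseSortingSegment.
import numpy as np


class ToySortingSegment:
    """Holds per-unit data for one segment (here: unit_id -> location)."""

    def __init__(self, unit_locations):
        self._unit_locations = unit_locations

    def get_unit_location(self, unit_id):
        return self._unit_locations[unit_id]


class ToySorting:
    """Extractor-level object: resolves the segment, then delegates to it."""

    def __init__(self, segments):
        self._segments = segments

    def get_unit_location(self, unit_id, segment_index=None):
        # Same shape as the diff above: an extractor-level getter that picks
        # the segment (defaulting to 0 for mono-segment sortings) and delegates.
        if segment_index is None:
            if len(self._segments) != 1:
                raise ValueError("segment_index is required for multi-segment sortings")
            segment_index = 0
        return self._segments[segment_index].get_unit_location(unit_id)


if __name__ == "__main__":
    locations = {0: np.array([12.5, 48.0]), 1: np.array([30.1, 7.2])}
    sorting = ToySorting([ToySortingSegment(locations)])
    print(sorting.get_unit_location(unit_id=1))  # -> [30.1  7.2]
```
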
```diff
@@ -79,12 +90,13 @@ def load_unit_info(self):


 class HerdingspikesSortingSegment(BaseSortingSegment):
-    def __init__(self, unit_ids, spike_times, spike_ids):
+    def __init__(self, unit_ids, spike_times, spike_ids, unit_locs):
         BaseSortingSegment.__init__(self)
         # spike_times is a dict
         self._unit_ids = list(unit_ids)
         self._spike_times = spike_times
         self._spike_ids = spike_ids
+        self._unit_locs = unit_locs

     def get_unit_spike_train(self, unit_id, start_frame, end_frame):
         mask = self._spike_ids == unit_id
@@ -95,6 +107,9 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
         times = times[times < end_frame]
         return times

+    def get_unit_location(self, unit_id):
+        return self._unit_locs[unit_id]
+
     """
     @staticmethod
     def write_sorting(sorting, save_path):
```
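
If the PR lands as written, usage could look roughly like the sketch below. The method name and behaviour follow the diff above; the example file path and the assumption that the HDF5 file carries a "centres" dataset with one row per unit are illustrative, not guaranteed by this WIP.

```python
# Hypothetical usage of the API proposed in this PR (WIP); the file path is an
# example and the exact return shape depends on the HS2 "centres" dataset.
from spikeinterface.extractors import HerdingspikesSortingExtractor

sorting = HerdingspikesSortingExtractor("HS2_sorted.hdf5")

for unit_id in sorting.unit_ids[:5]:
    # get_unit_location() is the method added by this diff: it asks the
    # sorting segment for the "centres" row associated with this unit.
    print(unit_id, sorting.get_unit_location(unit_id))
```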