diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index de4929218b..a83fbbb838 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -43,28 +43,41 @@ def __init__(self, file_path, load_unit_info=True): spike_ids = self._rf["cluster_id"][()] unit_ids = np.unique(spike_ids) spike_times = self._rf["times"][()] + unit_locs = self._rf["centres"][()] - if load_unit_info: - self.load_unit_info() + self.unit_locations = unit_locs + + # if load_unit_info: + # self.load_unit_info() BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids)) + self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids, self.unit_locations)) self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info} self.extra_requirements.append("h5py") - def load_unit_info(self): - # TODO + def get_unit_location( + self, + unit_id, + segment_index=None, + ): + + segment_index = self._check_segment_index(segment_index) + segment = self._sorting_segments[segment_index] + loc = segment.get_unit_location(unit_id=unit_id) + return loc + """ + def load_unit_info(self): + if 'centres' in self._rf.keys() and len(self._spike_times) > 0: self._unit_locs = self._rf['centres'][()] # cache for faster access - for u_i, unit_id in enumerate(self._unit_ids): - self.set_unit_property(unit_id, property_name='unit_location', value=self._unit_locs[u_i]) inds = [] # get these only once for unit_id in self._unit_ids: inds.append(np.where(self._cluster_id == unit_id)[0]) - if 'data' in self._rf.keys() and len(self._spike_times) > 0: - d = self._rf['data'][()] + if 'x' in self._rf.keys() and 'y' in self._rf.keys() and len(self._spike_times) > 0: + x = self._rf['x'][()] + y = self._rf['y'][()] for i, unit_id in enumerate(self._unit_ids): self.set_unit_spike_features(unit_id, 'spike_location', d[:, inds[i]].T) if 'ch' in self._rf.keys() and len(self._spike_times) > 0: @@ -79,12 +92,13 @@ def load_unit_info(self): class HerdingspikesSortingSegment(BaseSortingSegment): - def __init__(self, unit_ids, spike_times, spike_ids): + def __init__(self, unit_ids, spike_times, spike_ids, unit_locs): BaseSortingSegment.__init__(self) # spike_times is a dict self._unit_ids = list(unit_ids) self._spike_times = spike_times self._spike_ids = spike_ids + self._unit_locs = unit_locs def get_unit_spike_train(self, unit_id, start_frame, end_frame): mask = self._spike_ids == unit_id @@ -95,6 +109,9 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): times = times[times < end_frame] return times + def get_unit_location(self, unit_id): + return self._unit_locs[unit_id] + """ @staticmethod def write_sorting(sorting, save_path):