diff --git a/expipe/io/axona.py b/expipe/io/axona.py index 1ac690f..b016e68 100644 --- a/expipe/io/axona.py +++ b/expipe/io/axona.py @@ -107,7 +107,8 @@ def generate_analog_signals(exdir_path): lfp_timeseries.attrs["electrode_idx"] = analog_signal.channel_id - axona_channel_group.channel_group_id * 4 lfp_timeseries.attrs['electrode_group_id'] = axona_channel_group.channel_group_id data = lfp_timeseries.require_dataset("data", data=analog_signal.signal) - # NOTE: In exdirio (python-neo) sample rate is required on dset + data.attrs["num_samples"] = len(analog_signal.signal) + # NOTE: In exdirio (python-neo) sample rate is required on dset #TODO data.attrs["sample_rate"] = analog_signal.sample_rate @@ -124,13 +125,16 @@ def generate_clusters(exdir_path): if(axona_channel_group.channel_group_id == cut.channel_group_id): units = np.unique(cut.indices) cluster = channel_group.require_group("Clustering") - cluster.require_dataset("times", data=spike_train.times) - cluster.require_dataset("cluster_nums", data=units) - cluster.require_dataset("nums", data=cut.indices) cluster.attrs["start_time"] = start_time cluster.attrs["stop_time"] = stop_time # TODO: Add _ peak_over_rms as described in NWB cluster.attrs["peak_over_rms"] = None + times = cluster.require_dataset("times", data=spike_train.times) + times.attrs["num_samples"] = len(spike_train.times) + clnums = cluster.require_dataset("cluster_nums", data=units) + clnums.attrs["num_samples"] = len(units) + nums = cluster.require_dataset("nums", data=cut.indices) + nums.attrs["num_samples"] = len(scut.indices) def generate_units(exdir_path): @@ -142,13 +146,13 @@ def generate_units(exdir_path): spike_train = axona_channel_group.spike_train start_time = channel_group_segment['start_time'] stop_time = channel_group_segment['stop_time'] - + for cut in axona_file.cuts: if(axona_channel_group.channel_group_id == cut.channel_group_id): unit_times = channel_group.require_group("UnitTimes") unit_times.attrs["start_time"] = start_time unit_times.attrs["stop_time"] = stop_time - + unit_ids = [i for i in np.unique(cut.indices) if i > 0] unit_ids = np.array(unit_ids) - 1 # -1 for pyhton convention for index in unit_ids: @@ -156,8 +160,8 @@ def generate_units(exdir_path): indices = np.where(cut.indices == index)[0] times = spike_train.times[indices] unit.require_dataset("times", data=times) - - unit.attrs["cluster_group"] = "Unsorted" + unit.attrs['num_samples'] = len(times) + unit.attrs["cluster_group"] = "Unsorted" unit.attrs["cluster_id"] = int(index) # TODO: Add unit_description (e.g. cell type) and source as in NWB unit.attrs["source"] = None @@ -185,11 +189,17 @@ def generate_spike_trains(exdir_path): waveform_timeseries.attrs['electrode_group_id'] = axona_channel_group.channel_group_id waveform_timeseries.attrs["start_time"] = start_time waveform_timeseries.attrs["stop_time"] = stop_time + waveform_timeseries.attrs['sample_rate'] = spike_train.sample_rate if not isinstance(spike_train.waveforms, pq.Quantity): spike_train.waveforms = spike_train.waveforms * pq.uV # TODO fix pyxona data = waveform_timeseries.require_dataset("data", data=spike_train.waveforms) + data.attrs["num_samples"] = spike_train.spike_count + data.attrs["sample_length"] = spike_train.samples_per_spike + data.attrs["num_channels"] = len(channel_identities) data.attrs['sample_rate'] = spike_train.sample_rate - waveform_timeseries.require_dataset("timestamps", data=spike_train.times) + times = waveform_timeseries.require_dataset("timestamps", + data=spike_train.times) + times.attrs["num_samples"] = spike_train.spike_count def generate_tracking(exdir_path): @@ -207,8 +217,11 @@ def generate_tracking(exdir_path): tracked_spots = int(coords.shape[1] / 2) # 2 coordinates per spot for n in range(tracked_spots): led = position.require_group("led_" + str(n)) - led.require_dataset('data', coords[:, n * 2: n * 2 + 1 + 1]) - led.require_dataset("timestamps", times) + data = coords[:, n * 2: n * 2 + 1 + 1] + dset = led.require_dataset('data', data) + dset.attrs['num_samples'] = len(data) + dset = led.require_dataset("timestamps", times) + dset.attrs['num_samples'] = len(times) led.attrs['start_time'] = 0 * pq.s led.attrs['stop_time'] = axona_file._duration @@ -223,25 +236,28 @@ def generate_inp(exdir_path): if not all(key in inp.attrs for key in ['start_time', 'stop_time']): inp.attrs['start_time'] = 0 * pq.s inp.attrs['stop_time'] = axona_file._duration - + inp_data = axona_file.inp_data - inp.require_dataset('timestamps', inp_data.times) - inp.require_dataset('event_types', inp_data.event_types) - inp.require_dataset('values', inp_data.values) + times = inp.require_dataset('timestamps', inp_data.times) + times.attrs['num_samples'] = len(times) + types = inp.require_dataset('event_types', inp_data.event_types) + types.attrs['num_samples'] = len(times) + vals = inp.require_dataset('values', inp_data.values) + vals.attrs['num_samples'] = len(times) class AxonaFilerecord(Filerecord): def __init__(self, action, filerecord_id=None): super().__init__(action, filerecord_id) - + def import_file(self, axona_setfile): convert(axona_filename=axona_setfile, exdir_path=os.path.join(settings["data_path"], self.local_path)) - + def generate_tracking(self): generate_tracking(self.local_path) - + def generate_analog_signals(self): generate_analog_signals(self.local_path) - + def generate_spike_trains(self): generate_spike_trains(self.local_path)