add num samples to all datasets for visualization purposes and add sample rate to EventWaveform group #8
lepmik committed Feb 21, 2017
1 parent 77a95cc commit 8a158a2
Showing 1 changed file with 35 additions and 19 deletions.
expipe/io/axona.py: 35 additions & 19 deletions
@@ -107,7 +107,8 @@ def generate_analog_signals(exdir_path):
lfp_timeseries.attrs["electrode_idx"] = analog_signal.channel_id - axona_channel_group.channel_group_id * 4
lfp_timeseries.attrs['electrode_group_id'] = axona_channel_group.channel_group_id
data = lfp_timeseries.require_dataset("data", data=analog_signal.signal)
# NOTE: In exdirio (python-neo) sample rate is required on dset
data.attrs["num_samples"] = len(analog_signal.signal)
# NOTE: In exdirio (python-neo) sample rate is required on dset #TODO
data.attrs["sample_rate"] = analog_signal.sample_rate


@@ -124,13 +125,16 @@ def generate_clusters(exdir_path):
if(axona_channel_group.channel_group_id == cut.channel_group_id):
units = np.unique(cut.indices)
cluster = channel_group.require_group("Clustering")
cluster.require_dataset("times", data=spike_train.times)
cluster.require_dataset("cluster_nums", data=units)
cluster.require_dataset("nums", data=cut.indices)
cluster.attrs["start_time"] = start_time
cluster.attrs["stop_time"] = stop_time
# TODO: Add _ peak_over_rms as described in NWB
cluster.attrs["peak_over_rms"] = None
times = cluster.require_dataset("times", data=spike_train.times)
times.attrs["num_samples"] = len(spike_train.times)
clnums = cluster.require_dataset("cluster_nums", data=units)
clnums.attrs["num_samples"] = len(units)
nums = cluster.require_dataset("nums", data=cut.indices)
nums.attrs["num_samples"] = len(cut.indices)
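
The require_dataset / num_samples pairing now recurs for every dataset in this file; a hypothetical helper (not part of this commit) could collapse the pattern:

def require_dataset_with_count(group, name, data):
    # create or open the dataset and record its length for visualization tools
    dset = group.require_dataset(name, data=data)
    dset.attrs["num_samples"] = len(data)
    return dset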


def generate_units(exdir_path):
@@ -142,22 +146,22 @@ def generate_units(exdir_path):
spike_train = axona_channel_group.spike_train
start_time = channel_group_segment['start_time']
stop_time = channel_group_segment['stop_time']

for cut in axona_file.cuts:
if(axona_channel_group.channel_group_id == cut.channel_group_id):
unit_times = channel_group.require_group("UnitTimes")
unit_times.attrs["start_time"] = start_time
unit_times.attrs["stop_time"] = stop_time

unit_ids = [i for i in np.unique(cut.indices) if i > 0]
unit_ids = np.array(unit_ids) - 1  # -1 for Python's zero-based convention
for index in unit_ids:
unit = unit_times.require_group("unit_{}".format(index))
indices = np.where(cut.indices == index + 1)[0]  # +1: unit ids are the 1-based cluster numbers shifted down
times = spike_train.times[indices]
unit.require_dataset("times", data=times)

unit.attrs["cluster_group"] = "Unsorted"
unit.attrs['num_samples'] = len(times)
unit.attrs["cluster_group"] = "Unsorted"
unit.attrs["cluster_id"] = int(index)
# TODO: Add unit_description (e.g. cell type) and source as in NWB
unit.attrs["source"] = None
@@ -185,11 +189,17 @@ def generate_spike_trains(exdir_path):
waveform_timeseries.attrs['electrode_group_id'] = axona_channel_group.channel_group_id
waveform_timeseries.attrs["start_time"] = start_time
waveform_timeseries.attrs["stop_time"] = stop_time
waveform_timeseries.attrs['sample_rate'] = spike_train.sample_rate
if not isinstance(spike_train.waveforms, pq.Quantity):
spike_train.waveforms = spike_train.waveforms * pq.uV # TODO fix pyxona
data = waveform_timeseries.require_dataset("data", data=spike_train.waveforms)
data.attrs["num_samples"] = spike_train.spike_count
data.attrs["sample_length"] = spike_train.samples_per_spike
data.attrs["num_channels"] = len(channel_identities)
data.attrs['sample_rate'] = spike_train.sample_rate
waveform_timeseries.require_dataset("timestamps", data=spike_train.times)
times = waveform_timeseries.require_dataset("timestamps",
data=spike_train.times)
times.attrs["num_samples"] = spike_train.spike_count


def generate_tracking(exdir_path):
@@ -207,8 +217,11 @@ def generate_tracking(exdir_path):
tracked_spots = int(coords.shape[1] / 2) # 2 coordinates per spot
for n in range(tracked_spots):
led = position.require_group("led_" + str(n))
led.require_dataset('data', coords[:, n * 2: n * 2 + 1 + 1])
led.require_dataset("timestamps", times)
data = coords[:, n * 2: n * 2 + 2]  # the (x, y) columns for this LED
dset = led.require_dataset('data', data)
dset.attrs['num_samples'] = len(data)
dset = led.require_dataset("timestamps", times)
dset.attrs['num_samples'] = len(times)
led.attrs['start_time'] = 0 * pq.s
led.attrs['stop_time'] = axona_file._duration
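
On toy data, the column slicing above partitions a (time, 2 * spots) array into per-LED (x, y) pairs:

import numpy as np

coords = np.arange(12).reshape(3, 4)      # 3 time points, 2 LEDs * (x, y)
for n in range(2):
    led_xy = coords[:, n * 2: n * 2 + 2]  # columns (2n, 2n + 1) for one LED
    assert led_xy.shape == (3, 2)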

@@ -223,25 +236,28 @@ def generate_inp(exdir_path):
if not all(key in inp.attrs for key in ['start_time', 'stop_time']):
inp.attrs['start_time'] = 0 * pq.s
inp.attrs['stop_time'] = axona_file._duration

inp_data = axona_file.inp_data
inp.require_dataset('timestamps', inp_data.times)
inp.require_dataset('event_types', inp_data.event_types)
inp.require_dataset('values', inp_data.values)
times = inp.require_dataset('timestamps', inp_data.times)
times.attrs['num_samples'] = len(inp_data.times)
types = inp.require_dataset('event_types', inp_data.event_types)
types.attrs['num_samples'] = len(inp_data.event_types)
vals = inp.require_dataset('values', inp_data.values)
vals.attrs['num_samples'] = len(inp_data.values)
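
Since all three INP datasets are written with the same length, a consumer can sanity-check their alignment from the attrs alone; a read-back sketch with hypothetical paths:

import exdir

f = exdir.File("/path/to/session.exdir")   # hypothetical path
inp = f["epochs"]["inp"]                   # hypothetical group location
n = inp["timestamps"].attrs["num_samples"]
assert inp["event_types"].attrs["num_samples"] == n
assert inp["values"].attrs["num_samples"] == n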


class AxonaFilerecord(Filerecord):
def __init__(self, action, filerecord_id=None):
super().__init__(action, filerecord_id)

def import_file(self, axona_setfile):
convert(axona_filename=axona_setfile, exdir_path=os.path.join(settings["data_path"], self.local_path))

def generate_tracking(self):
generate_tracking(self.local_path)

def generate_analog_signals(self):
generate_analog_signals(self.local_path)

def generate_spike_trains(self):
generate_spike_trains(self.local_path)
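
A hypothetical end-to-end use of the wrapper, assuming an expipe action object and a placeholder .set file path:

filerecord = AxonaFilerecord(action)
filerecord.import_file("/data/rat01/session01.set")  # convert and write exdir
filerecord.generate_tracking()
filerecord.generate_analog_signals()
filerecord.generate_spike_trains()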
