-
Notifications
You must be signed in to change notification settings - Fork 7
/
data_loader.py
71 lines (57 loc) · 2.25 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Data loader logic with two main responsibilities:
(i) download raw data and process; this logic is initiated upon import
(ii) helper functions for dealing with mini-batches, sequence packing, etc.
Data are taken from
Boulanger-Lewandowski, N., Bengio, Y. and Vincent, P.,
"Modeling Temporal Dependencies in High-Dimensional Sequences: Application to
Polyphonic Music Generation and Transcription"
however, the original source of the data seems to be the Institut fuer Algorithmen
und Kognitive Systeme at Universitaet Karlsruhe.
"""
import os
import numpy as np
import six.moves.cPickle as pickle
import torch
import torch.nn as nn
import torch.utils.data as data
class PolyphonicDataset(data.Dataset):
    """Map-style dataset of sequences and their lengths read from a pickle file.

    The pickle is expected to contain a mapping with the keys
    ``'sequences'`` (an indexable collection of numpy arrays; presumably
    2-D ``(time, features)`` -- TODO confirm) and ``'seq_lens'`` (an
    indexable collection of numpy integers, one per sequence).

    Each item is a tuple ``(seq, rev_seq, seq_len)``:

    * ``seq`` -- the stored sequence cast to ``float32``;
    * ``rev_seq`` -- a ``float32`` copy of ``seq`` reversed along its first axis;
    * ``seq_len`` -- the stored length cast to ``int64``.
    """

    def __init__(self, filepath):
        """Load sequences and sequence lengths from the pickle at *filepath*.

        Args:
            filepath: path to a pickle file holding a mapping with
                ``'sequences'`` and ``'seq_lens'`` entries.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if the pickled mapping lacks a required key.
        """
        print("loading data...")
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        # The ``with`` block guarantees the file handle is closed, even if
        # unpickling raises.
        with open(filepath, "rb") as fh:
            # Not named ``data`` to avoid shadowing the torch.utils.data alias.
            contents = pickle.load(fh)
        self.seqs = contents['sequences']
        self.seqlens = contents['seq_lens']
        self.data_len = len(self.seqs)
        print("{} entries".format(self.data_len))

    def __getitem__(self, offset):
        """Return ``(seq, rev_seq, seq_len)`` for the sequence at *offset*."""
        seq = self.seqs[offset].astype('float32')
        # Reverse the entire array along its first (time) axis. ``.copy()``
        # materialises the negative-stride view into a contiguous array,
        # since torch.from_numpy rejects arrays with negative strides.
        # NOTE(review): if stored sequences are padded beyond seq_len, the
        # padding ends up at the front of rev_seq -- confirm this is intended.
        rev_seq = seq[::-1].copy()
        seq_len = self.seqlens[offset].astype('int64')
        return seq, rev_seq, seq_len

    def __len__(self):
        """Return the number of sequences loaded."""
        return self.data_len
class SyntheticDataset(data.Dataset):
    """Map-style dataset of sequences, lengths and per-item ``z`` values.

    The pickle is expected to contain a mapping with the keys
    ``'sequences'`` (an indexable collection of numpy arrays; presumably
    2-D ``(time, features)`` -- TODO confirm), ``'seq_lens'`` (an indexable
    collection of numpy integers, one per sequence) and ``'z'`` (an
    indexable collection aligned with ``'sequences'``; presumably the
    latent states that generated each sequence -- TODO confirm).

    Each item is a tuple ``(seq, rev_seq, seq_len, z)``:

    * ``seq`` -- the stored sequence cast to ``float32``;
    * ``rev_seq`` -- a ``float32`` copy of ``seq`` reversed along its first axis;
    * ``seq_len`` -- the stored length cast to ``int64``;
    * ``z`` -- the stored ``z`` entry, returned unchanged.
    """

    def __init__(self, filepath):
        """Load sequences, lengths and ``z`` values from the pickle at *filepath*.

        Args:
            filepath: path to a pickle file holding a mapping with
                ``'sequences'``, ``'seq_lens'`` and ``'z'`` entries.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if the pickled mapping lacks a required key.
        """
        print("loading data...")
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        # The ``with`` block guarantees the file handle is closed, even if
        # unpickling raises.
        with open(filepath, "rb") as fh:
            # Not named ``data`` to avoid shadowing the torch.utils.data alias.
            contents = pickle.load(fh)
        self.seqs = contents['sequences']
        self.seqlens = contents['seq_lens']
        self.z = contents['z']
        self.data_len = len(self.seqs)
        print("{} entries".format(self.data_len))

    def __getitem__(self, offset):
        """Return ``(seq, rev_seq, seq_len, z)`` for the item at *offset*."""
        seq = self.seqs[offset].astype('float32')
        # Reverse the entire array along its first (time) axis. ``.copy()``
        # materialises the negative-stride view into a contiguous array,
        # since torch.from_numpy rejects arrays with negative strides.
        # NOTE(review): if stored sequences are padded beyond seq_len, the
        # padding ends up at the front of rev_seq -- confirm this is intended.
        rev_seq = seq[::-1].copy()
        seq_len = self.seqlens[offset].astype('int64')
        z = self.z[offset]
        return seq, rev_seq, seq_len, z

    def __len__(self):
        """Return the number of sequences loaded."""
        return self.data_len