forked from deepsound-project/samplernn-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_ignite.py
176 lines (148 loc) · 5.18 KB
/
test_ignite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from model import SampleRNN
from train import *
from trainer import create_supervised_trainer, create_supervised_evaluator
from ignite.engine import Events
from ignite.metrics.loss import Loss
# TODO: add an accuracy metric alongside 'nll' once the Loss metric is confirmed to work
def main(exp, frame_sizes, dataset, **params):
    """Build a SampleRNN predictor and train it with ignite for two epochs.

    The keyword arguments are merged over ``default_params``. The merged dict
    drives the model, the data loaders, the trainer and the evaluator. After
    every epoch the average 'nll' metric is printed for the training split and
    for the validation split.

    Args:
        exp: experiment name, stored in the merged parameter dict.
        frame_sizes: frame sizes passed to ``SampleRNN``.
        dataset: dataset name, stored in the merged parameter dict.
        **params: overrides for entries of ``default_params``.
    """
    print('main called')
    params = dict(
        default_params,
        exp=exp, frame_sizes=frame_sizes, dataset=dataset,
        **params
    )
    results_path = setup_results_dir(params)
    # Presumably mirrors stdout into <results_path>/log -- confirm in train.py.
    tee_stdout(os.path.join(results_path, 'log'))

    model = SampleRNN(
        frame_sizes=params['frame_sizes'],
        n_rnn=params['n_rnn'],
        dim=params['dim'],
        learn_h0=params['learn_h0'],
        q_levels=params['q_levels'],
        weight_norm=params['weight_norm']
    )
    predictor = Predictor(model)
    use_cuda = params['cuda']
    if use_cuda:
        model = model.cuda()
        predictor = predictor.cuda()

    optimizer = torch.optim.Adam(predictor.parameters())
    # optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters()))
    loss = sequence_nll_loss_bits

    # The dataset is split by fraction:
    #   [0, val_split) train, [val_split, test_split) validation,
    #   [test_split, 1) test.
    test_split = 1 - params['test_frac']
    val_split = test_split - params['val_frac']
    data_loader = make_data_loader(model.lookback, params)
    train_loader = data_loader(0, val_split, eval=False)
    val_loader = data_loader(val_split, test_split, eval=True)
    # Built but not evaluated anywhere in this script yet.
    test_loader = data_loader(test_split, 1, eval=True)

    trainer = create_supervised_trainer(
        predictor, optimizer, loss, use_cuda)
    # The evaluator must get the same CUDA flag as the trainer. When
    # params['cuda'] is set, the predictor lives on the GPU, so evaluation
    # batches have to be moved there as well.
    # NOTE(review): assumes the local trainer module mirrors ignite 0.1's
    # create_supervised_evaluator(model, metrics={}, cuda=False) signature.
    evaluator = create_supervised_evaluator(
        predictor,
        metrics={
            'nll': Loss(loss)
        },
        cuda=use_cuda)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        # Re-evaluates the whole training split at the end of every epoch.
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        print("Training Results - Epoch: {} Avg loss: {:.2f}"
              .format(trainer.state.epoch, metrics['nll']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print("Validation Results - Epoch: {} Avg loss: {:.2f}"
              .format(trainer.state.epoch, metrics['nll']))

    # Short smoke run: the epoch count is fixed here, and params['epoch_limit']
    # is not consulted by this script.
    trainer.run(train_loader, max_epochs=2)
    print('train complete!')
# Command-line parser for this script. The help output includes each
# option's default value. With argument_default=SUPPRESS, an option that is
# not given on the command line is left out of the parsed namespace unless a
# default has been registered for it.
parser = argparse.ArgumentParser(
    argument_default=argparse.SUPPRESS,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
def parse_bool(arg):
    """Parse a case-insensitive boolean command-line value.

    Any non-empty prefix of ``'true'`` (``'t'``, ``'tr'``, ...) yields
    ``True``. Any non-empty prefix of ``'false'`` yields ``False``.

    Args:
        arg: the raw option string.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if ``arg`` is empty or is not a prefix of ``'true'`` or
            ``'false'``. When used as an argparse ``type=``, argparse reports
            this as an invalid value for the option.
    """
    value = arg.lower()
    # '' is a prefix of every string. Without this guard an empty value would
    # silently parse as True.
    if value and 'true'.startswith(value):
        return True
    if value and 'false'.startswith(value):
        return False
    raise ValueError(
        'expected a prefix of "true" or "false", got {!r}'.format(arg))
# Command-line options. Options that are not given fall back to the values
# registered from default_params via parser.set_defaults below.
parser.add_argument('--exp', required=True, help='experiment name')
parser.add_argument(
    '--frame_sizes', nargs='+', type=int, required=True,
    help='frame sizes in terms of the number of lower tier frames, '
         'starting from the lowest RNN tier'
)
parser.add_argument(
    '--dataset', required=True,
    help='dataset name - name of a directory in the datasets path '
         '(settable by --datasets_path)'
)
parser.add_argument(
    '--n_rnn', type=int, help='number of RNN layers in each tier'
)
parser.add_argument(
    '--dim', type=int, help='number of neurons in every RNN and MLP layer'
)
parser.add_argument(
    '--learn_h0', type=parse_bool,
    help='whether to learn the initial states of RNNs'
)
parser.add_argument(
    '--q_levels', type=int,
    help='number of bins in quantization of audio samples'
)
parser.add_argument(
    '--seq_len', type=int,
    help='how many samples to include in each truncated BPTT pass'
)
parser.add_argument(
    '--weight_norm', type=parse_bool,
    help='whether to use weight normalization'
)
parser.add_argument('--batch_size', type=int, help='batch size')
parser.add_argument(
    '--val_frac', type=float,
    help='fraction of data to go into the validation set'
)
parser.add_argument(
    '--test_frac', type=float,
    help='fraction of data to go into the test set'
)
parser.add_argument(
    '--keep_old_checkpoints', type=parse_bool,
    help='whether to keep checkpoints from past epochs'
)
parser.add_argument(
    '--datasets_path', help='path to the directory containing datasets'
)
parser.add_argument(
    '--results_path', help='path to the directory to save the results to'
)
# Without type=int a value given on the command line would arrive as a str.
parser.add_argument('--epoch_limit', type=int, help='how many epochs to run')
parser.add_argument(
    '--resume', type=parse_bool, default=True,
    help='whether to resume training from the last checkpoint'
)
parser.add_argument(
    '--sample_rate', type=int,
    help='sample rate of the training data and generated sound'
)
parser.add_argument(
    '--n_samples', type=int,
    help='number of samples to generate in each epoch'
)
parser.add_argument(
    '--sample_length', type=int,
    help='length of each generated sample (in samples)'
)
parser.add_argument(
    '--cuda', type=parse_bool,
    help='whether to use CUDA'
)
parser.set_defaults(**default_params)

# Parse the command line and train only when this file is run as a script.
# The file is named test_*.py, so a test runner such as pytest would otherwise
# import it during collection and start training with the runner's own argv.
if __name__ == '__main__':
    main(**vars(parser.parse_args()))