-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtrain_measure_vae.py
executable file
·131 lines (121 loc) · 4.15 KB
/
train_measure_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import click
from DatasetManager.dataset_manager import DatasetManager
from DatasetManager.the_session.folk_dataset import FolkDataset
from DatasetManager.metadata import TickMetadata, BeatMarkerMetadata
from MeasureVAE.measure_vae import MeasureVAE
from MeasureVAE.vae_trainer import VAETrainer
from MeasureVAE.vae_tester import VAETester
from utils.helpers import *
@click.command()
@click.option('--note_embedding_dim', default=10,
              help='size of the note embeddings')
@click.option('--metadata_embedding_dim', default=2,
              help='size of the metadata embeddings')
@click.option('--num_encoder_layers', default=2,
              help='number of layers in encoder RNN')
@click.option('--encoder_hidden_size', default=512,
              help='hidden size of the encoder RNN')
@click.option('--encoder_dropout_prob', default=0.5,
              help='float, amount of dropout prob between encoder RNN layers')
@click.option('--has_metadata', default=False,
              help='bool, True if data contains metadata')
@click.option('--latent_space_dim', default=256,
              help='int, dimension of latent space parameters')
@click.option('--num_decoder_layers', default=2,
              help='int, number of layers in decoder RNN')
@click.option('--decoder_hidden_size', default=512,
              help='int, hidden size of the decoder RNN')
@click.option('--decoder_dropout_prob', default=0.5,
              help='float, amount of dropout prob between decoder RNN layers')
@click.option('--batch_size', default=256,
              help='training batch size')
@click.option('--num_epochs', default=30,
              help='number of training epochs')
@click.option('--train/--test', default=True,
              help='train the model, or load a saved model and test it')
@click.option('--plot/--no_plot', default=False,
              help='plot the training log')
@click.option('--log/--no_log', default=True,
              help='log the results for tensorboard')
@click.option('--lr', default=1e-4,
              help='learning rate')
def main(note_embedding_dim,
         metadata_embedding_dim,
         num_encoder_layers,
         encoder_hidden_size,
         encoder_dropout_prob,
         latent_space_dim,
         num_decoder_layers,
         decoder_hidden_size,
         decoder_dropout_prob,
         has_metadata,
         batch_size,
         num_epochs,
         train,
         plot,
         log,
         lr
         ):
    """Train a MeasureVAE on the folk dataset, or load one and test it.

    With --train (the default) the model is trained with VAETrainer on the
    dataset requested with train=True. With --test the model is restored
    via model.load() and evaluated with VAETester on the dataset requested
    with train=False. The GPU is used only when torch reports CUDA available.
    """
    # Import torch explicitly rather than relying on `utils.helpers *`
    # happening to re-export it.
    import torch

    dataset_manager = DatasetManager()
    # Both datasets share the same metadata objects.
    metadatas = [
        BeatMarkerMetadata(subdivision=6),
        TickMetadata(subdivision=6)
    ]
    # The train and test requests differ only in the `train` flag.
    common_kwargs = {
        'metadatas': metadatas,
        'sequences_size': 32,
        'num_bars': 16,
    }
    mvae_train_kwargs = {**common_kwargs, 'train': True}
    mvae_test_kwargs = {**common_kwargs, 'train': False}

    folk_dataset: FolkDataset = dataset_manager.get_dataset(
        name='folk_4by4nbars_train',
        **mvae_train_kwargs
    )
    # NOTE(review): the test dataset is requested under the same
    # '..._train' name; the split is presumably selected by the `train`
    # kwarg inside DatasetManager. Confirm before renaming.
    folk_dataset_test: FolkDataset = dataset_manager.get_dataset(
        name='folk_4by4nbars_train',
        **mvae_test_kwargs
    )

    # The model is always constructed from the training dataset, also in
    # test mode.
    model = MeasureVAE(
        dataset=folk_dataset,
        note_embedding_dim=note_embedding_dim,
        metadata_embedding_dim=metadata_embedding_dim,
        num_encoder_layers=num_encoder_layers,
        encoder_hidden_size=encoder_hidden_size,
        encoder_dropout_prob=encoder_dropout_prob,
        latent_space_dim=latent_space_dim,
        num_decoder_layers=num_decoder_layers,
        decoder_hidden_size=decoder_hidden_size,
        decoder_dropout_prob=decoder_dropout_prob,
        has_metadata=has_metadata
    )

    # Decide the device once so both branches behave the same way. The
    # test branch previously called model.cuda() unconditionally, which
    # fails on CPU-only machines.
    use_cuda = torch.cuda.is_available()

    if train:
        if use_cuda:
            model.cuda()
        trainer = VAETrainer(
            dataset=folk_dataset,
            model=model,
            lr=lr
        )
        trainer.train_model(
            batch_size=batch_size,
            num_epochs=num_epochs,
            plot=plot,
            log=log
        )
    else:
        # NOTE(review): whether model.load() and VAETester work without a
        # GPU (e.g. map_location handling) depends on their implementations,
        # which are not visible here.
        model.load()
        if use_cuda:
            model.cuda()
        model.eval()
        tester = VAETester(
            dataset=folk_dataset_test,
            model=model
        )
        tester.test_model()
if __name__ == '__main__':
    # Script entry point: invoking the click command with no arguments makes
    # click read the options from the command line.
    main()