Skip to content

Commit

Permalink
분산트레이닝 추가, [#1] 제너레이션 코드 추가, 주피터 제거
Browse files Browse the repository at this point in the history
  • Loading branch information
jason9693 committed Oct 30, 2019
1 parent cd841c4 commit c3be3bd
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 71 deletions.
File renamed without changes.
3 changes: 1 addition & 2 deletions config/generate.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
save_path: 'bin/generated.mid'
condition_file: 'dataset/midi/BENABD10.mid'
condition_file:
length: 2048
load_path: None
2 changes: 1 addition & 1 deletion config/large.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
experiment: 'mt_large'
experiment: 'embedding512-layer6'
max_seq: 2048
embedding_dim: 512
num_layers: 6
Expand Down
4 changes: 2 additions & 2 deletions config/train.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pickle_dir: 'MusicTransformer/dataset/processed'
epochs: 100
pickle_dir: '../MusicTransformer/dataset/processed'
epochs: 1000
batch_size: 8
load_path:
dropout: 0.1
Expand Down
5 changes: 3 additions & 2 deletions custom/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ def __init__(self, ignore_index=-100, reduction='mean') -> None:
super().__init__(reduction='none')

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
mask = (target != self.ignore_index).to(input.device, dtype=torch.float32)
target = target.to(torch.long)
mask = (target != self.ignore_index).to(input.device, dtype=torch.long)
not_masked_length = mask.to(torch.int).sum()
input = input.permute(0, -1, -2)
_loss = super().forward(input, target)
_loss *= mask
_loss *= mask.to(_loss.dtype)
return _loss.sum() / not_masked_length

def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Expand Down
21 changes: 8 additions & 13 deletions custom/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(self, num_layers, d_model, input_vocab_size, rate=0.1, max_len=None

self.enc_layers = torch.nn.ModuleList(
[EncoderLayer(d_model, rate, h=self.d_model // 64, additional=False, max_seq=max_len)
for i in range(num_layers)])
for _ in range(num_layers)])
self.dropout = torch.nn.Dropout(rate)

def forward(self, x, mask=None):
Expand All @@ -234,15 +234,10 @@ def forward(self, x, mask=None):
return x, weights # (batch_size, input_seq_len, d_model)


class MusicTransformerDataParallel(torch.nn.DataParallel):
    """``DataParallel`` wrapper that is transparent to attribute access.

    Attributes that the wrapper itself does not define are looked up on the
    wrapped ``module``, and ``forward`` passes both positional and keyword
    arguments through to the underlying implementation.
    """

    def __getattr__(self, name):
        """Return attribute *name*, falling back to the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module
                provides *name*.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == 'module':
                # The wrapped module is not registered (yet); delegating to
                # ``self.module`` here would re-enter this method forever.
                raise
            return getattr(self.module, name)

    def forward(self, *inputs, **kwargs):
        """Run the parallel forward pass, forwarding keyword arguments too.

        If ``DataParallel.forward`` raises ``NotImplementedError`` the wrapped
        module is called directly with the same arguments.
        """
        try:
            return super().forward(*inputs, **kwargs)
        except NotImplementedError:
            return self.module(*inputs, **kwargs)
# class MusicTransformerDataParallelCriterion(torch.nn.DataParallel):
# def forward(self, inputs, *targets, **kwargs):
# targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
# replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
# targets = tuple(targets_per_gpu[0] for targets_per_gpu in targets)
# outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
# return Reduce.apply(*outputs) / len(outputs), targets
16 changes: 15 additions & 1 deletion custom/metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from custom.parallel import DataParallelCriterion

import torch
import numpy as np
import torch.nn.functional as F
Expand Down Expand Up @@ -55,7 +57,7 @@ def __init__(self, vocab_size):
super().__init__()

def forward(self, input: torch.Tensor, target: torch.Tensor):
return input.argmax(-1).flatten().to(torch.int32).cpu()
return input.argmax(-1).flatten().to(torch.int32)


class MetricsSet(object):
Expand All @@ -73,6 +75,18 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
for k, metric in self.metrics.items()}


class ParallelMetricSet(MetricsSet):
    """Metric set whose members are each wrapped in ``DataParallelCriterion``."""

    def __init__(self, metric_dict: Dict):
        """Initialise the base set, then replace ``self.metrics`` with wrapped metrics.

        Args:
            metric_dict: mapping from metric name to the metric object to wrap.
        """
        super().__init__(metric_dict)
        wrapped = {}
        for name, metric in metric_dict.items():
            wrapped[name] = DataParallelCriterion(metric)
        self.metrics = wrapped

    def forward(self, input, target):
        """Call every wrapped metric on ``(input, target)``.

        Returns:
            A dict mapping each metric name to that metric's result.
        """
        results = {}
        for name, metric in self.metrics.items():
            results[name] = metric(input, target)
        return results


if __name__ == '__main__':
met = MockAccuracy()
test_tensor1 = torch.ones((3,2)).contiguous().cuda().to(non_blocking=True, dtype=torch.int)
Expand Down
41 changes: 4 additions & 37 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def _get_seq(self, fname, max_length=None):
start = random.randrange(0,len(data) - max_length)
data = data[start:start + max_length]
else:
data = np.append(data, config.token_eos)
while len(data) < max_length:
data = np.append(data, config.pad_token)
raise IndexError
# data = np.append(data, config.token_eos)
# while len(data) < max_length:
# data = np.append(data, config.pad_token)
return data


Expand Down Expand Up @@ -134,37 +135,3 @@ def count_dict(max_length, data):
except KeyError:
cnt_dict['index-'+str(index)] = 1
return cnt_arr

# print(add_noise(np.array([[1,2,3,3,4,5,6]]), rate=0.2))


# print(par.vocab_size)
# data = Data('dataset/processed')
# # ds = DataSequence('dataset/processed', 10, 2048)
# sample = data.seq2seq_batch(1000, 100)[0]
# pprint.pprint(list(sample))
# arr = count_dict(par.vocab_size+3,sample)
# pprint.pprint(
# arr)
#
# from sequence import EventSeq, Event
#
# event_cnt = {
# 'note_on': 0,
# 'note_off': 0,
# 'velocity': 0,
# 'time_shift': 0
# }
# for event_index in range(len(arr)):
# for event_type, feat_range in EventSeq.feat_ranges().items():
#
# if feat_range.start <= event_index < feat_range.stop:
# print(event_type+':'+str(arr[event_index])+' event cnt: '+str(event_cnt))
# event_cnt[event_type] += arr[event_index]
#
# print(event_cnt)

# print(np.max(sample), np.min(sample))
# print([data._get_seq(file).shape for file in data.files])
#while True:
# print(ds.__getitem__(10)[1].argmax(-1))
8 changes: 4 additions & 4 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@
max_seq=config.max_seq,
dropout=0,
debug=False)
mt = torch.load_state_dict(args.model_dir+'/final.pth')
mt.eval()
mt.load_state_dict(torch.load(args.model_dir+'/final.pth'))
mt.test()

print(config.condition_file)
if config.condition_file is not None:
inputs = np.array([encode_midi('dataset/midi/BENABD10.mid')[:500]])
else:
inputs = np.array([[28]])
inputs = torch.from_numpy([inputs]).to(config.device)

inputs = torch.from_numpy(inputs)
result = mt(inputs, config.length, gen_summary_writer)

for i in result:
Expand Down
18 changes: 11 additions & 7 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MusicTransformer(torch.nn.Module):
def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
max_seq=2048, dropout=0.2, debug=False, loader_path=None, dist=False, writer=None):
super().__init__()

self.infer = False
if loader_path is not None:
self.load_config_file(loader_path)
else:
Expand All @@ -36,21 +36,21 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)

def forward(self, x, length=None, writer=None):
if self.training or self.eval:
if self.training or not self.infer:
_, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token)
decoder, w = self.Decoder(x, mask=look_ahead_mask)
fc = self.fc(decoder)
return fc.contiguous(), [weight.contiguous() for weight in w]
return fc.contiguous() if self.training else fc.contiguous(), [weight.contiguous() for weight in w]
else:
return self.generate(self.Decoder, x, length, writer).contiguous()

def generate(self, decode_fn, prior: torch.Tensor, length=2048, tf_board_writer: SummaryWriter = None):
decode_array = prior
for i in Bar('generating').iter(range(min(self.max_seq, length))):
if decode_array.shape[1] >= self.max_seq:
if decode_array.size(1) >= self.max_seq:
break
_, _, look_ahead_mask = \
utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array)
utils.get_masked_with_pad_tensor(decode_array.size(1), decode_array, decode_array, pad_token=config.pad_token)

# result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)
result, _ = decode_fn(decode_array, look_ahead_mask)
Expand All @@ -66,9 +66,13 @@ def generate(self, decode_fn, prior: torch.Tensor, length=2048, tf_board_writer:
decode_array = torch.cat([decode_array, result.unsqueeze(-1)], -1)
else:
pdf = dist.OneHotCategorical(probs=result[:, -1])
result = pdf.sample(1)
result = torch.transpose(result, 1, 0).to(torch.int32)
result = pdf.sample().argmax(-1).unsqueeze(0)
# result = torch.transpose(result, 1, 0).to(torch.int32)
decode_array = torch.cat((decode_array, result), dim=-1)
del look_ahead_mask
decode_array = decode_array[0]
return decode_array

def test(self):
    """Put the model into inference mode.

    Calls ``self.eval()`` and then sets the ``infer`` flag to ``True``.
    """
    self.eval()
    self.infer = True
1 change: 0 additions & 1 deletion music-transformer-train.ipynb

This file was deleted.

2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@

start_time = time.time()
mt.train()
sample, _ = mt.forward(batch_x)
sample = mt.forward(batch_x)
metrics = metric_set(sample, batch_y)
loss = metrics['loss']
loss.backward()
Expand Down

0 comments on commit c3be3bd

Please sign in to comment.