
Commit

Remove TensorFlow legacy code, add requirements, add a config directory, implement metric classes
jason9693 committed Oct 12, 2019
1 parent f9349b8 commit d6d7e6a
Showing 11 changed files with 183 additions and 140 deletions.
11 changes: 11 additions & 0 deletions config/base.yml
@@ -0,0 +1,11 @@
max_seq: 2048
l_r: 0.001
embedding_dim: 256
num_attention_layer: 6
batch_size: 10
loss_type: 'categorical_crossentropy'
event_dim: 388
#pad_token: event_dim
##token_sos: event_dim + 1
##token_eos: event_dim + 2
##vocab_size: event_dim + 3
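
Note: the commented-out keys show that pad_token, token_sos, token_eos and vocab_size are meant to be derived from event_dim rather than stored in the file. A minimal sketch of loading this config and computing those derived values, assuming PyYAML; the load_config helper below is illustrative and not part of this commit:

import yaml

def load_config(path='config/base.yml'):
    # only the keys listed above come from the file; the rest follow the
    # commented-out formulas (pad_token = event_dim, ..., vocab_size = event_dim + 3)
    with open(path) as f:
        cfg = yaml.safe_load(f)
    cfg['pad_token'] = cfg['event_dim']
    cfg['token_sos'] = cfg['event_dim'] + 1
    cfg['token_eos'] = cfg['event_dim'] + 2
    cfg['vocab_size'] = cfg['event_dim'] + 3
    return cfg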
38 changes: 19 additions & 19 deletions custom/callback.py → custom/criterion.py
@@ -1,8 +1,12 @@
from tensorflow.python import keras
import tensorflow as tf
from typing import Optional, Any

import params as par
import sys
from tensorflow.python.keras.optimizer_v2.learning_rate_schedule import LearningRateSchedule

from torch.__init__ import Tensor
import torch
from torch.nn.modules.loss import CrossEntropyLoss
# from tensorflow.python.keras.optimizer_v2.learning_rate_schedule import LearningRateSchedule


class MTFitCallback(keras.callbacks.Callback):
@@ -15,24 +19,20 @@ def on_epoch_end(self, epoch, logs=None):
self.model.save(self.save_path)


class TransformerLoss(keras.losses.SparseCategoricalCrossentropy):
def __init__(self, from_logits=False, reduction='none', debug=False, **kwargs):
super(TransformerLoss, self).__init__(from_logits, reduction, **kwargs)
self.debug = debug
pass
class TransformerLoss(CrossEntropyLoss):
def __init__(self, weight: Optional[Any] = ..., ignore_index: int = ..., reduction: str = ...) -> None:
self.reduction = reduction
super().__init__(weight, ignore_index, 'none')

def call(self, y_true, y_pred):
y_true = tf.cast(y_true, tf.int32)
mask = tf.math.logical_not(tf.math.equal(y_true, par.pad_token))
mask = tf.cast(mask, tf.float32)
_loss = super(TransformerLoss, self).call(y_true, y_pred)
def forward(self, input: Tensor, target: Tensor) -> Tensor:
mask = target != par.pad_token
not_masked_length = mask.to(torch.int).sum()
_loss = super().forward(input, target)
_loss *= mask
if self.debug:
tf.print('loss shape:', _loss.shape, output_stream=sys.stdout)
tf.print('output:', tf.argmax(y_pred,-1), output_stream=sys.stdout)
tf.print(mask, output_stream=sys.stdout)
tf.print(_loss, output_stream=sys.stdout)
return _loss
return _loss.sum() / not_masked_length

def __call__(self, input: Tensor, target: Tensor) -> Tensor:
return self.forward(input, target)


def transformer_dist_train_loss(y_true, y_pred):
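
Note: the new TransformerLoss keeps the per-token losses (reduction 'none' in the super().__init__ call), zeroes out positions equal to par.pad_token, and divides by the number of unmasked positions. A self-contained sketch of that masked average on dummy tensors; the pad value 388 is an assumption taken from event_dim in config/base.yml:

import torch

pad_token = 388                               # assumed par.pad_token (event_dim)
logits = torch.randn(4, 391)                  # (positions, vocab) -- flattened batch
targets = torch.tensor([10, 25, pad_token, pad_token])

criterion = torch.nn.CrossEntropyLoss(reduction='none')
per_token = criterion(logits, targets)        # one loss value per position
mask = (targets != pad_token).to(per_token.dtype)
loss = (per_token * mask).sum() / mask.sum()  # average over non-pad positions only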
5 changes: 2 additions & 3 deletions custom/layers.py
@@ -118,12 +118,11 @@ def _skewing(self, tensor: torch.Tensor):
padded = F.pad(tensor, [0, 0, 0, 0, 0, 0, 1, 0])
reshaped = torch.reshape(padded, shape=[-1, padded.size(1), padded.size(-1), padded.size(-2)])
Srel = reshaped[:, :, 1:, :]
# print('Sre: {}'.format(Srel))

if self.len_k > self.len_q:
Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q])
elif self.len_k < self.len_q:
Srel = Srel[:,:,:,:self.len_k]
Srel = Srel[:, :, :, :self.len_k]

return Srel

@@ -224,4 +223,4 @@ def call(self, x, mask=None):
for i in range(self.num_layers):
x, w = self.enc_layers[i](x, mask)
weights.append(w)
return x, weights # (batch_size, input_seq_len, d_model)
return x, weights # (batch_size, input_seq_len, d_model)
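
Note: _skewing implements the Music Transformer "skewing" trick for relative position attention: pad one dummy column, reshape, and slice so each query row lines up with its relative logits. A sketch of the same procedure on a single (len_q, len_k) matrix; the repo applies it to a 4-D (batch, head, q, k) tensor and adds the extra padding/slicing shown above when len_k != len_q:

import torch
import torch.nn.functional as F

L = 4
qe = torch.randn(L, L)                 # Q @ Er^T for one head
padded = F.pad(qe, [1, 0])             # one dummy column on the left -> (L, L+1)
reshaped = padded.reshape(L + 1, L)    # shifts each row by its index
srel = reshaped[1:, :]                 # drop the first row -> (L, L) relative logits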
35 changes: 35 additions & 0 deletions custom/metrics.py
@@ -0,0 +1,35 @@
import torch
from typing import List


class _Metric(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input: torch.Tensor, target: torch.Tensor):
pass


class CategoricalAccuracy(_Metric):
def __init__(self):
super().__init__()

def forward(self, input: torch.Tensor, target: torch.Tensor):
pass


class Accuracy(_Metric):
def __init__(self):
super().__init__()

def forward(self, input: torch.Tensor, target: torch.Tensor):
pass


class MetricsSet(_Metric):
def __init__(self, metrics: List[_Metric]):
super().__init__()
self.metrics = metrics

def forward(self, input: torch.Tensor, target: torch.Tensor):
return [metric(input, target) for metric in self.metrics]
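
Note: CategoricalAccuracy and Accuracy are added as stubs here (their forward bodies are still pass); only MetricsSet is functional, fanning one (input, target) pair out to every metric it holds. A hypothetical fill-in for an accuracy forward, masking pad tokens the same way TransformerLoss does; none of this body is in the commit:

import torch

def accuracy(input: torch.Tensor, target: torch.Tensor, pad_token: int = 388):
    # input: (..., vocab) logits, target: (...) integer ids; the pad_token value is assumed
    pred = input.argmax(dim=-1)
    mask = target != pad_token
    correct = (pred == target) & mask
    return correct.sum().float() / mask.sum().clamp(min=1).float()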
6 changes: 3 additions & 3 deletions deprecated/train.py
@@ -1,6 +1,6 @@
from model import MusicTransformer
from custom.layers import *
from custom import callback
from custom import criterion
import params as par
from tensorflow.python.keras.optimizer_v2.adam import Adam
from data import Data
@@ -44,7 +44,7 @@


# load model
learning_rate = callback.CustomSchedule(par.embedding_dim) if l_r is None else l_r
learning_rate = criterion.CustomSchedule(par.embedding_dim) if l_r is None else l_r
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)


@@ -56,7 +56,7 @@
max_seq=max_seq,
dropout=0.2,
debug=False, loader_path=load_path)
mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)
mt.compile(optimizer=opt, loss=criterion.transformer_dist_train_loss)


# define tensorboard writer
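
Note: the deprecated Keras trainer only changes its import; it still builds the learning rate from criterion.CustomSchedule(par.embedding_dim) and the usual Adam(beta_1=0.9, beta_2=0.98, epsilon=1e-9). CustomSchedule itself is outside this diff, so the formula below is an assumption that it follows the standard Transformer warm-up schedule:

def transformer_lr(step: int, d_model: int = 256, warmup_steps: int = 4000) -> float:
    # lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)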
6 changes: 3 additions & 3 deletions dist_train.py
@@ -1,6 +1,6 @@
from model import MusicTransformer
from custom.layers import *
from custom import callback
from custom import criterion
import params as par
from tensorflow.python.keras.optimizer_v2.adam import Adam
from data import Data
@@ -43,7 +43,7 @@


# load model
learning_rate = callback.CustomSchedule(par.embedding_dim)
learning_rate = criterion.CustomSchedule(par.embedding_dim)
opt = Adam(l_r, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

strategy = tf.distribute.MirroredStrategy()
@@ -58,7 +58,7 @@
max_seq=max_seq,
dropout=0.2,
debug=False, loader_path=load_path)
mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)
mt.compile(optimizer=opt, loss=criterion.transformer_dist_train_loss)

# Train Start
for e in range(epochs):
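
Note: dist_train.py stays on TensorFlow's MirroredStrategy while only the loss import moves to custom.criterion. For compile to pick up mirrored variables, model and optimizer are normally created inside the strategy scope; a generic sketch of that pattern (this diff does not show where the scope opens, and build_model is a stand-in, not repo code):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # variables created here are mirrored across the visible GPUs
    model = build_model()   # hypothetical helper standing in for MusicTransformer(...)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')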
2 changes: 1 addition & 1 deletion generate.py
@@ -1,6 +1,6 @@
from model import MusicTransformer, MusicTransformerDecoder
from custom.layers import *
from custom import callback
from custom import criterion
import params as par
from tensorflow.python.keras.optimizer_v2.adam import Adam
from data import Data
84 changes: 46 additions & 38 deletions model.py
@@ -1,5 +1,5 @@
from custom.layers import *
from custom.callback import *
from custom.criterion import *
from custom.layers import Encoder
import params as par

@@ -16,7 +16,7 @@

class MusicTransformer(torch.nn.Module):
def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
max_seq=2048, dropout=0.2, debug=False, loader_path=None, dist=False):
max_seq=2048, dropout=0.2, debug=False, loader_path=None, dist=False, writer=None):
super().__init__()

if loader_path is not None:
@@ -29,22 +29,24 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
self.vocab_size = vocab_size
self.dist = dist

self.writer = writer
self.Decoder = Encoder(
num_layers=self.num_layer, d_model=self.embedding_dim,
input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq)
self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)

self._set_metrics()

def forward(self, x):
def forward(self, x, lookup_mask=None):
decoder, w = self.Decoder(x, mask=lookup_mask)
fc = self.fc(decoder)
if self.training:
return fc
elif eval:
return fc, w
else:
return F.softmax(fc)
return fc, w
# if self.training:
# return fc
# elif eval:
# return fc, w
# else:
# return F.softmax(fc)

def generate(self, prior: list, length=2048, tf_board=False):
decode_array = np.array([prior])
@@ -58,25 +60,33 @@ def generate(self, prior: list, length=2048, tf_board=False):
_, _, look_ahead_mask = \
utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array)

result = self.call(decode_array, lookup_mask=look_ahead_mask, training=False)
if tf_board:
tf.summary.image('generate_vector', tf.expand_dims(result, -1), i)
# import sys
# tf.print('[debug out:]', result, sys.stdout )
result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)

u = random.uniform(0, 1)
if u > 1:
result = F.argmax(result[:, -1], -1).to(torch.int32)
decode_array = tf.concat([decode_array, tf.expand_dims(result, -1)], -1)
decode_array = torch.cat([decode_array, result.unsqueeze(-1)], -1)
else:
pdf = dist.OneHotCategorical(probs=result[:, -1])
result = pdf.sample(1)
result = torch.transpose(result, (1, 0)).to(torch.int32)
result = torch.transpose(result, 1, 0).to(torch.int32)
decode_array = torch.cat((decode_array, result), dim=-1)
# decode_array = tf.concat([decode_array, tf.expand_dims(result[:, -1], 0)], -1)
del look_ahead_mask
decode_array = decode_array[0]
return decode_array

def train_forward(self, x):
x, _ = self.__prepare_train_data(x, x)
_, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x)

predictions, _ = self.forward(
x, lookup_mask=look_ahead_mask,
)

if self._debug:
print('train step finished')
return predictions


class MusicTransformerDecoder(torch.nn.Module):
def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
@@ -141,12 +151,10 @@ def train_on_batch(self, x, y=None, sample_weight=None, class_weight=None, reset

return [loss.numpy()]+result_metric

# @tf.function
def __dist_train_step(self, inp_tar, out_tar, lookup_mask, training):
return self._distribution_strategy.experimental_run_v2(
self.__train_step, args=(inp_tar, out_tar, lookup_mask, training))

# @tf.function
def __train_step(self, inp_tar, out_tar, lookup_mask, training):
with tf.GradientTape() as tape:
predictions = self.call(
@@ -326,22 +334,22 @@ def __prepare_train_data(x, y):
# x = data.add_noise(x, rate=0.01)
return x, y


if __name__ == '__main__':
# import utils
print(tf.executing_eagerly())

src = tf.constant([utils.fill_with_placeholder([1,2,3,4],max_len=2048)])
trg = tf.constant([utils.fill_with_placeholder([1,2,3,4],max_len=2048)])
src_mask, trg_mask, lookup_mask = utils.get_masked_with_pad_tensor(2048, src,trg)
print(lookup_mask)
print(src_mask)
mt = MusicTransformer(debug=True, embedding_dim=par.embedding_dim, vocab_size=par.vocab_size)
mt.save_weights('my_model.h5', save_format='h5')
mt.load_weights('my_model.h5')
result = mt.generate([27, 186, 43, 213, 115, 131], length=100)
print(result)
from deprecated import sequence

sequence.EventSeq.from_array(result[0]).to_note_seq().to_midi_file('result.midi')
pass
#
# if __name__ == '__main__':
# # import utils
# print(tf.executing_eagerly())
#
# src = tf.constant([utils.fill_with_placeholder([1,2,3,4],max_len=2048)])
# trg = tf.constant([utils.fill_with_placeholder([1,2,3,4],max_len=2048)])
# src_mask, trg_mask, lookup_mask = utils.get_masked_with_pad_tensor(2048, src,trg)
# print(lookup_mask)
# print(src_mask)
# mt = MusicTransformer(debug=True, embedding_dim=par.embedding_dim, vocab_size=par.vocab_size)
# mt.save_weights('my_model.h5', save_format='h5')
# mt.load_weights('my_model.h5')
# result = mt.generate([27, 186, 43, 213, 115, 131], length=100)
# print(result)
# from deprecated import sequence
#
# sequence.EventSeq.from_array(result[0]).to_note_seq().to_midi_file('result.midi')
# pass
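
Note: in model.py, forward now takes an explicit lookup_mask and always returns (fc, w), train_forward builds that mask with utils.get_masked_with_pad_tensor before calling forward, and generate samples the next token from torch.distributions.OneHotCategorical instead of using TF ops. A self-contained sketch of that sampling step; the softmax call and the shapes are assumptions, since forward returns raw fc logits:

import torch
import torch.distributions as dist

logits = torch.randn(1, 391)                        # (batch, vocab) output for the last position
probs = torch.softmax(logits, dim=-1)
pdf = dist.OneHotCategorical(probs=probs)
one_hot = pdf.sample()                              # (batch, vocab) one-hot draw
next_token = one_hot.argmax(dim=-1, keepdim=True)   # (batch, 1) integer token id
prior = torch.tensor([[355]])                       # running decode_array (value is arbitrary)
decode_array = torch.cat([prior, next_token], dim=-1)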
5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,3 +1,5 @@
torch
tensorboardX
absl-py==0.7.1
alembic==1.0.11
appdirs==1.4.3
@@ -89,10 +91,7 @@ SQLAlchemy==1.3.5
sqlparse==0.3.0
ssh-import-id==5.7
tabulate==0.8.3
tb-nightly==1.14.0a20190603
tensorflow-gpu==2.0.0b1
termcolor==1.1.0
tf-estimator-nightly==1.14.0.dev2019060501
tfp-nightly==0.8.0.dev20190807
treelib==1.5.5
urllib3==1.22