diff --git a/custom/criterion.py b/custom/criterion.py
index b0bae09..d80f5de 100644
--- a/custom/criterion.py
+++ b/custom/criterion.py
@@ -1,74 +1,87 @@
 from typing import Optional, Any
-import params as par
-import sys
-
-from torch.__init__ import Tensor
 import torch
-from torch.nn.modules.loss import CrossEntropyLoss
+from torch.nn.modules.loss import CrossEntropyLoss, _Loss

 # from tensorflow.python.keras.optimizer_v2.learning_rate_schedule import LearningRateSchedule


-class MTFitCallback(keras.callbacks.Callback):
-
-    def __init__(self, save_path):
-        super(MTFitCallback, self).__init__()
-        self.save_path = save_path
-
-    def on_epoch_end(self, epoch, logs=None):
-        self.model.save(self.save_path)
-
-
 class TransformerLoss(CrossEntropyLoss):
     def __init__(self, weight: Optional[Any] = ..., ignore_index: int = ..., reduction: str = ...) -> None:
         self.reduction = reduction
+        self.ignore_index = ignore_index
         super().__init__(weight, ignore_index, 'none')

-    def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        mask = target != par.pad_token
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        mask = target != self.ignore_index
         not_masked_length = mask.to(torch.int).sum()
         input = input.permute(0, -1, -2)
         _loss = super().forward(input, target)
         _loss *= mask
         return _loss.sum() / not_masked_length

-    def __call__(self, input: Tensor, target: Tensor) -> Tensor:
+    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return self.forward(input, target)


-def transformer_dist_train_loss(y_true, y_pred):
-    y_true = tf.cast(y_true, tf.int32)
-    mask = tf.math.logical_not(tf.math.equal(y_true, par.pad_token))
-    mask = tf.cast(mask, tf.float32)
-
-    y_true_vector = tf.one_hot(y_true, par.vocab_size)
-
-    _loss = tf.nn.softmax_cross_entropy_with_logits(y_true_vector, y_pred)
-    # print(_loss.shape)
-    #
-    # _loss = tf.reduce_mean(_loss, -1)
-    _loss *= mask
-
-    return _loss
-
-
-class CustomSchedule(LearningRateSchedule):
-    def __init__(self, d_model, warmup_steps=4000):
-        super(CustomSchedule, self).__init__()
-
-        self.d_model = d_model
-        self.d_model = tf.cast(self.d_model, tf.float32)
-
-        self.warmup_steps = warmup_steps
-
-    def get_config(self):
-        super(CustomSchedule, self).get_config()
-
-    def __call__(self, step):
-        arg1 = tf.math.rsqrt(step)
-        arg2 = step * (self.warmup_steps ** -1.5)
-
-        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
+class SmoothCrossEntropyLoss(_Loss):
+    """
+    https://arxiv.org/abs/1512.00567
+    """
+    __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']
+
+    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean'):
+        assert 0.0 <= label_smoothing <= 1.0
+        super().__init__(reduction=reduction)
+
+        self.label_smoothing = label_smoothing
+        self.vocab_size = vocab_size
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        """
+        Args:
+            input: [B * T, V]
+            target: [B * T]
+        Returns:
+            cross entropy: [1]
+        """
+        mask = (target == self.ignore_index).unsqueeze(1)
+
+        q = torch.nn.functional.one_hot(target, self.vocab_size).type(torch.float32)
+        u = 1.0 / self.vocab_size
+        q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
+        q_prime = q_prime.masked_fill(mask, 0)
+
+        ce = self.cross_entropy_with_logits(q_prime, input)
+        if self.reduction == 'mean':
+            lengths = torch.sum(target != self.ignore_index)
+            return ce.sum() / lengths
+        elif self.reduction == 'sum':
+            return ce.sum()
+        else:
+            raise NotImplementedError
+
+    def cross_entropy_with_logits(self, p, q):
+        return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)
+
+
+# class CustomSchedule(LearningRateSchedule):
+#     def __init__(self, d_model, warmup_steps=4000):
+#         super(CustomSchedule, self).__init__()
+#
+#         self.d_model = d_model
+#         self.d_model = tf.cast(self.d_model, tf.float32)
+#
+#         self.warmup_steps = warmup_steps
+#
+#     def get_config(self):
+#         super(CustomSchedule, self).get_config()
+#
+#     def __call__(self, step):
+#         arg1 = tf.math.rsqrt(step)
+#         arg2 = step * (self.warmup_steps ** -1.5)
+#
+#         return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


 if __name__ == '__main__':
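Not part of the patch: a minimal usage sketch of the new `SmoothCrossEntropyLoss`, following the shape convention from its docstring ([B * T, V] logits, [B * T] targets). The vocabulary size, pad id, smoothing value, and batch shape below are made-up placeholders standing in for the repo's `config.vocab_size` / `config.pad_token`.

```python
import torch
from custom.criterion import SmoothCrossEntropyLoss

vocab_size, pad_token = 390, 388   # placeholder values, not read from the repo config
criterion = SmoothCrossEntropyLoss(label_smoothing=0.1, vocab_size=vocab_size,
                                   ignore_index=pad_token)

logits = torch.randn(2, 16, vocab_size, requires_grad=True)   # [B, T, V] model output
targets = torch.randint(0, vocab_size, (2, 16))               # [B, T] token ids, padded with pad_token

# flatten to the [B * T, V] / [B * T] layout the forward() docstring expects
loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
loss.backward()   # padded positions contribute zero and are excluded from the mean
```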
diff --git a/custom/layers.py b/custom/layers.py
index f18e2c2..5f0bd49 100644
--- a/custom/layers.py
+++ b/custom/layers.py
@@ -56,7 +56,7 @@ def __init__(self, h=4, d=256, add_emb=False, max_seq=2048, **kwargs):
         self.Wv = torch.nn.Linear(self.d, self.d)
         self.fc = torch.nn.Linear(d, d)
         self.additional = add_emb
-        self.E = self.add_weight('emb', shape=[self.max_seq, int(self.dh)])
+        self.E = torch.randn([self.max_seq, int(self.dh)], requires_grad=True)
         if self.additional:
             self.Radd = None
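An aside on the `custom/layers.py` change above (not part of the patch): a tensor created with `torch.randn(..., requires_grad=True)` and stored as a plain attribute is not registered by `nn.Module`, so it is skipped by `.parameters()` (and therefore by any optimizer built from `model.parameters()`) and does not follow `.to(device)`. If that is unintended, the usual idiom is `nn.Parameter`; a toy sketch with an invented module name:

```python
import torch


class RelativeEmbeddingDemo(torch.nn.Module):
    """Illustrative stand-in, not the repo's attention layer."""

    def __init__(self, h=4, d=256, max_seq=2048):
        super().__init__()
        self.dh = d // h
        self.max_seq = max_seq
        # nn.Parameter registers E on the module, so it shows up in .parameters()
        # and moves with .to(device) / .cuda() together with the other weights
        self.E = torch.nn.Parameter(torch.randn(self.max_seq, self.dh))


demo = RelativeEmbeddingDemo()
print(any(p is demo.E for p in demo.parameters()))   # True
```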
diff --git a/data.py b/data.py
index e9b238c..ab51f78 100644
--- a/data.py
+++ b/data.py
@@ -1,9 +1,9 @@
 import utils
 import random
 import pickle
-from tensorflow.python import keras
 import numpy as np
-import params as par
+
+from custom.config import config


 class Data:
@@ -84,9 +84,9 @@ def _get_seq(self, fname, max_length=None):
                 start = random.randrange(0,len(data) - max_length)
                 data = data[start:start + max_length]
             else:
-                data = np.append(data, par.token_eos)
+                data = np.append(data, config.token_eos)
                 while len(data) < max_length:
-                    data = np.append(data, par.pad_token)
+                    data = np.append(data, config.pad_token)
         return data

@@ -111,7 +111,7 @@ def add_noise(inputs: np.array, rate:float = 0.01): # input's dim is 2
     num_mask = int(rate * seq_length)
     for inp in inputs:
         rand_idx = random.sample(range(seq_length), num_mask)
-        inp[rand_idx] = random.randrange(0, par.pad_token)
+        inp[rand_idx] = random.randrange(0, config.pad_token)

     return inputs
diff --git a/generate.py b/generate.py
index 5d11cf2..87b72a7 100644
--- a/generate.py
+++ b/generate.py
@@ -1,4 +1,5 @@
 from custom.layers import *
+import custom
 from custom import criterion
 from data import Data
 from custom.config import config
@@ -10,9 +11,8 @@
 from tensorboardX import SummaryWriter


-parser = argparse.ArgumentParser()
-args = parser.parse_args()
-
+parser = custom.get_argument_parser()
+args=parser.parse_args()
 config.load(args.model_dir, args.configs, initialize=True)

 # check cuda
diff --git a/model.py b/model.py
index daf3611..7de1351 100644
--- a/model.py
+++ b/model.py
@@ -9,7 +9,7 @@
 import utils

 import torch
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 from progress.bar import Bar

@@ -34,8 +34,6 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
             input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq)
         self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)

-        self._set_metrics()
-
     def forward(self, x, lookup_mask=None):
         decoder, w = self.Decoder(x, mask=lookup_mask)
         fc = self.fc(decoder)
diff --git a/train.py b/train.py
index cb08d23..906b668 100644
--- a/train.py
+++ b/train.py
@@ -1,7 +1,8 @@
 from model import MusicTransformer
+import custom
 from custom.metrics import *
 from custom.layers import *
-from custom.criterion import TransformerLoss
+from custom.criterion import SmoothCrossEntropyLoss
 from custom.config import config
 from data import Data

@@ -9,12 +10,13 @@
 import argparse
 import datetime

+import torch
 import torch.optim as optim
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter


 # set config
-parser = argparse.ArgumentParser()
+parser = custom.get_argument_parser()
 args = parser.parse_args()
 config.load(args.model_dir, args.configs, initialize=True)

@@ -37,11 +39,16 @@
 mt = MusicTransformer(
     embedding_dim=config.embedding_dim,
     vocab_size=config.vocab_size,
-    num_layer=config.num_layer,
+    num_layer=config.num_layers,
     max_seq=config.max_seq,
     dropout=config.dropout,
     debug=config.debug, loader_path=config.load_path
 )
+opt = optim.Adam(mt.parameters(), lr=config.l_r)
+metric_set = MetricsSet({
+    'accuracy': Accuracy(),
+    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token)
+})

 # multi-GPU set
 if torch.cuda.device_count() > 1:
@@ -50,10 +57,6 @@
 else:
     single_mt = mt

-criterion = TransformerLoss
-opt = optim.Adam(mt.parameters(), lr=config.l_r)
-metric_set = MetricsSet({'accuracy': Accuracy(), 'loss': TransformerLoss()})
-
 # define tensorboard writer
 current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
 train_log_dir = 'logs/mt_decoder/'+current_time+'/train'
@@ -88,7 +91,7 @@
             eval_preiction, weights = mt.teacher_forcing_forward(eval_x)
             eval_metrics = metric_set(eval_preiction, eval_y)

-            torch.save(single_mt, config.save_path)
+            torch.save(single_mt, config.model_dir+'train-{}.pth'.format(idx))
             if b == 0:
                 train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e)
                 train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e)
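Not part of the patch: a rough sketch of how the objects defined in `train.py` above (`mt`, `opt`, and the smoothed loss) could be wired into a single training step. The batch tensors, the model call returning [B, T, V] logits, and calling the criterion directly instead of going through `MetricsSet` are all assumptions made for illustration.

```python
# assumed: batch_x, batch_y are [B, T] LongTensors produced by Data, padded with config.pad_token
criterion = SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token)

mt.train()
opt.zero_grad()
prediction = mt(batch_x)                       # assumed to yield [B, T, V] logits
loss = criterion(prediction.view(-1, config.vocab_size), batch_y.view(-1))
loss.backward()
opt.step()
```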
diff --git a/utils.py b/utils.py
index c707cd6..01b3169 100644
--- a/utils.py
+++ b/utils.py
@@ -3,7 +3,7 @@
 from deprecated.sequence import EventSeq, ControlSeq
 import torch
 import torch.nn.functional as F
-from custom.config import config
+# from custom.config import config


 def find_files_by_extensions(root, exts=[]):
@@ -54,20 +54,21 @@ def compute_gradient_norm(parameters, norm_type=2):
     return total_norm


-def get_masked_with_pad_tensor(size, src, trg):
+def get_masked_with_pad_tensor(size, src, trg, pad_token):
     """
     :param size: the size of target input
     :param src: source tensor
     :param trg: target tensor
+    :param pad_token: pad token
     :return:
     """
     src = src[:, None, None, :]
     trg = trg[:, None, None, :]
-    src_pad_tensor = torch.ones_like(src) * config.pad_token
+    src_pad_tensor = torch.ones_like(src) * pad_token
     src_mask = torch.equal(src, src_pad_tensor)
     trg_mask = torch.equal(src, src_pad_tensor)
     if trg is not None:
-        trg_pad_tensor = torch.ones_like(trg) * config.pad_token
+        trg_pad_tensor = torch.ones_like(trg) * pad_token
         dec_trg_mask = trg == trg_pad_tensor
         # boolean reversing i.e) True * -1 + 1 = False
         seq_mask = sequence_mask(torch.arange(1, size+1), size) * -1 + 1

@@ -89,12 +90,12 @@
     return seq_mask


-def fill_with_placeholder(prev_data: list, max_len: int, fill_val: float=config.pad_token):
+def fill_with_placeholder(prev_data: list, max_len: int, fill_val: float):
     placeholder = [fill_val for _ in range(max_len - len(prev_data))]
     return prev_data + placeholder


-def pad_with_length(max_length: int, seq: list, pad_val: float=config.pad_token):
+def pad_with_length(max_length: int, seq: list, pad_val: float):
     """
     :param max_length: max length of token
     :param seq: token list with shape:(length, dim)
@@ -106,9 +107,9 @@
     :param pad_val: padding value
     :return:
     """
     pad = [[pad_val] * len(seq[0])] * (max_length - len(seq))
     return seq + pad


-def append_token(data: torch.Tensor):
-    start_token = torch.ones((data.size(0), 1), dtype=data.dtype) * config.token_sos
-    end_token = torch.ones((data.size(0), 1), dtype=data.dtype) * config.token_eos
+def append_token(data: torch.Tensor, eos_token):
+    start_token = torch.ones((data.size(0), 1), dtype=data.dtype) * eos_token
+    end_token = torch.ones((data.size(0), 1), dtype=data.dtype) * eos_token

     return torch.cat([start_token, data, end_token], -1)
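Not part of the patch: because the utility functions above no longer read `config` at import time, call sites now pass the special tokens explicitly. A small sketch of the updated calls, assuming `config` has already been populated by `config.load(...)` as in `train.py`:

```python
import torch
import utils
from custom.config import config

x = torch.randint(0, config.vocab_size, (2, 16))   # [B, T] dummy source batch
y = torch.randint(0, config.vocab_size, (2, 16))   # [B, T] dummy target batch

# the pad token is now an explicit argument instead of a module-level config lookup
masks = utils.get_masked_with_pad_tensor(16, x, y, config.pad_token)

# same idea for the eos token appended around a sequence
framed = utils.append_token(x, config.token_eos)
print(framed.shape)   # torch.Size([2, 18])
```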