diff --git a/custom/criterion.py b/custom/criterion.py index 6ab6b2c..b0bae09 100644 --- a/custom/criterion.py +++ b/custom/criterion.py @@ -27,6 +27,7 @@ def __init__(self, weight: Optional[Any] = ..., ignore_index: int = ..., reducti def forward(self, input: Tensor, target: Tensor) -> Tensor: mask = target != par.pad_token not_masked_length = mask.to(torch.int).sum() + input = input.permute(0, -1, -2) _loss = super().forward(input, target) _loss *= mask return _loss.sum() / not_masked_length diff --git a/custom/metrics.py b/custom/metrics.py index 3e78406..95517f6 100644 --- a/custom/metrics.py +++ b/custom/metrics.py @@ -1,5 +1,7 @@ import torch -from typing import List +import torch.nn.functional as F + +from typing import Dict class _Metric(torch.nn.Module): @@ -7,29 +9,32 @@ def __init__(self): super().__init__() def forward(self, input: torch.Tensor, target: torch.Tensor): - pass + raise NotImplementedError() -class CategoricalAccuracy(_Metric): +class Accuracy(_Metric): def __init__(self): super().__init__() def forward(self, input: torch.Tensor, target: torch.Tensor): - pass + bool_acc = input == target + return bool_acc.sum() / bool_acc.numel() -class Accuracy(_Metric): +class CategoricalAccuracy(Accuracy): def __init__(self): super().__init__() def forward(self, input: torch.Tensor, target: torch.Tensor): - pass + categorical_input = input.argmax(-1) + return super().forward(categorical_input, target) class MetricsSet(_Metric): - def __init__(self, metrics: List[_Metric]): + def __init__(self, metric_dict: Dict): super().__init__() - self.metrics = metrics + self.metrics = metric_dict def forward(self, input: torch.Tensor, target: torch.Tensor): - return [metric(input, target) for metric in self.metrics] \ No newline at end of file + # return [metric(input, target) for metric in self.metrics] + return {k: metric(input, target) for k, metric in self.metrics.items()} diff --git a/dist_train.py b/deprecated/dist_train.py similarity index 100% rename from dist_train.py rename to deprecated/dist_train.py diff --git a/generate.py b/generate.py index 27f6a84..4192f02 100644 --- a/generate.py +++ b/generate.py @@ -35,18 +35,16 @@ gen_summary_writer = tf.summary.create_file_writer(gen_log_dir) -if mode == 'enc-dec': - print(">> generate with original seq2seq wise... beam size is {}".format(beam)) - mt = MusicTransformer( - embedding_dim=256, - vocab_size=par.vocab_size, - num_layer=6, - max_seq=2048, - dropout=0.2, - debug=False, loader_path=load_path) -else: - print(">> generate with decoder wise... beam size is {}".format(beam)) - mt = MusicTransformerDecoder(loader_path=load_path) +print(">> generate with original seq2seq wise... beam size is {}".format(beam)) +# mt = MusicTransformer( +# embedding_dim=256, +# vocab_size=par.vocab_size, +# num_layer=6, +# max_seq=2048, +# dropout=0.2, +# debug=False, loader_path=load_path) +mt = torch.load(load_path) +mt.eval() inputs = encode_midi('dataset/midi/BENABD10.mid') diff --git a/model.py b/model.py index 27ed330..73d91a2 100644 --- a/model.py +++ b/model.py @@ -6,12 +6,10 @@ import sys import torch import torch.distributions as dist -import json import random import utils import torch -import torch.functional as F class MusicTransformer(torch.nn.Module): @@ -40,19 +38,13 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6, def forward(self, x, lookup_mask=None): decoder, w = self.Decoder(x, mask=lookup_mask) fc = self.fc(decoder) + fc = fc.softmax(-1) return fc, w - # if self.training: - # return fc - # elif eval: - # return fc, w - # else: - # return F.softmax(fc) def generate(self, prior: list, length=2048, tf_board=False): decode_array = np.array([prior]) decode_array = torch.from_numpy(decode_array) for i in range(min(self.max_seq, length)): - # print(decode_array.shape[1]) if decode_array.shape[1] >= self.max_seq: break if i % 100 == 0: @@ -64,7 +56,7 @@ def generate(self, prior: list, length=2048, tf_board=False): u = random.uniform(0, 1) if u > 1: - result = F.argmax(result[:, -1], -1).to(torch.int32) + result = result[:, -1].argmax(-1).to(torch.int32) decode_array = torch.cat([decode_array, result.unsqueeze(-1)], -1) else: pdf = dist.OneHotCategorical(probs=result[:, -1]) @@ -75,264 +67,267 @@ def generate(self, prior: list, length=2048, tf_board=False): decode_array = decode_array[0] return decode_array - def train_forward(self, x): + def teacher_forcing_forward(self, x, attn=False): x, _ = self.__prepare_train_data(x, x) _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) - predictions, _ = self.forward( + predictions, w = self.forward( x, lookup_mask=look_ahead_mask, ) if self._debug: print('train step finished') - return predictions - - -class MusicTransformerDecoder(torch.nn.Module): - def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6, - max_seq=2048, dropout=0.2, debug=False, loader_path=None, dist=False): - super(MusicTransformerDecoder, self).__init__() - - if loader_path is not None: - self.load_config_file(loader_path) - else: - self._debug = debug - self.max_seq = max_seq - self.num_layer = num_layer - self.embedding_dim = embedding_dim - self.vocab_size = vocab_size - self.dist = dist - - self.Decoder = Encoder( - num_layers=self.num_layer, d_model=self.embedding_dim, - input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq) - self.fc = keras.layers.Dense(self.vocab_size, activation=None, name='output') - - self._set_metrics() - - if loader_path is not None: - self.load_ckpt_file(loader_path) - - def call(self, inputs, training=None, eval=None, lookup_mask=None): - decoder, w = self.Decoder(inputs, training=training, mask=lookup_mask) - fc = self.fc(decoder) - if training: - return fc - elif eval: - return fc, w - else: - return tf.nn.softmax(fc) - - def train_on_batch(self, x, y=None, sample_weight=None, class_weight=None, reset_metrics=True): - if self._debug: - tf.print('sanity:\n', self.sanity_check(x, y, mode='d'), output_stream=sys.stdout) - - x, y = self.__prepare_train_data(x, y) - - _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) - - if self.dist: - predictions = self.__dist_train_step( - x, y, look_ahead_mask, True) + if attn: + return predictions, w else: - predictions = self.__train_step(x, y, look_ahead_mask, True) - - if self._debug: - print('train step finished') - result_metric = [] - - if self.dist: - loss = self._distribution_strategy.reduce(tf.distribute.ReduceOp.MEAN, self.loss_value, None) - else: - loss = tf.reduce_mean(self.loss_value) - loss = tf.reduce_mean(loss) - for metric in self.custom_metrics: - result_metric.append(metric(y, predictions).numpy()) - - return [loss.numpy()]+result_metric - - def __dist_train_step(self, inp_tar, out_tar, lookup_mask, training): - return self._distribution_strategy.experimental_run_v2( - self.__train_step, args=(inp_tar, out_tar, lookup_mask, training)) - - def __train_step(self, inp_tar, out_tar, lookup_mask, training): - with tf.GradientTape() as tape: - predictions = self.call( - inputs=inp_tar, lookup_mask=lookup_mask, training=training - ) - self.loss_value = self.loss(out_tar, predictions) - gradients = tape.gradient(self.loss_value, self.trainable_variables) - self.grad = gradients - self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) - - return predictions - - def evaluate(self, x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None, callbacks=None, - max_queue_size=10, workers=1, use_multiprocessing=False): - - # x, inp_tar, out_tar = MusicTransformer.__prepare_train_data(x, y) - _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) - predictions, w = self.call( - x, lookup_mask=look_ahead_mask, training=False, eval=True) - loss = tf.reduce_mean(self.loss(y, predictions)) - result_metric = [] - for metric in self.custom_metrics: - result_metric.append(metric(y, tf.nn.softmax(predictions)).numpy()) - return [loss.numpy()] + result_metric, w - - def save(self, filepath, overwrite=True, include_optimizer=False, save_format=None): - config_path = filepath+'/'+'config.json' - ckpt_path = filepath+'/ckpt' - - self.save_weights(ckpt_path, save_format='tf') - with open(config_path, 'w') as f: - json.dump(self.get_config(), f) - return - - def load_config_file(self, filepath): - config_path = filepath + '/' + 'config.json' - with open(config_path, 'r') as f: - config = json.load(f) - self.__load_config(config) - - def load_ckpt_file(self, filepath, ckpt_name='ckpt'): - ckpt_path = filepath + '/' + ckpt_name - try: - self.load_weights(ckpt_path) - except FileNotFoundError: - print("[Warning] model will be initialized...") - - def sanity_check(self, x, y, mode='v', step=None): - # mode: v -> vector, d -> dict - # x, inp_tar, out_tar = self.__prepare_train_data(x, y) - - _, tar_mask, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) - predictions = self.call( - x, lookup_mask=look_ahead_mask, training=False) - - if mode == 'v': - tf.summary.image('vector', tf.expand_dims(predictions, -1), step) return predictions - elif mode == 'd': - dic = {} - for row in tf.argmax(predictions, -1).numpy(): - for col in row: - try: - dic[str(col)] += 1 - except KeyError: - dic[str(col)] = 1 - return dic - else: - tf.summary.image('tokens', tf.argmax(predictions, -1), step) - return tf.argmax(predictions, -1) - - def get_config(self): - config = {} - config['debug'] = self._debug - config['max_seq'] = self.max_seq - config['num_layer'] = self.num_layer - config['embedding_dim'] = self.embedding_dim - config['vocab_size'] = self.vocab_size - config['dist'] = self.dist - return config - - def generate(self, prior: list, beam=None, length=2048, tf_board=False): - decode_array = prior - decode_array = tf.constant([decode_array]) - # TODO: add beam search - if beam is not None: - k = beam - for i in range(min(self.max_seq, length)): - if decode_array.shape[1] >= self.max_seq: - break - if i % 100 == 0: - print('generating... {}% completed'.format((i/min(self.max_seq, length))*100)) - _, _, look_ahead_mask = \ - utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array) - result = self.call(decode_array, lookup_mask=look_ahead_mask, training=False, eval=False) - if tf_board: - tf.summary.image('generate_vector', tf.expand_dims([result[0]], -1), i) - - result = result[:,-1,:] - result = tf.reshape(result, (1, -1)) - result, result_idx = tf.nn.top_k(result, k) - row = result_idx // par.vocab_size - col = result_idx % par.vocab_size - - result_array = [] - for r, c in zip(row[0], col[0]): - prev_array = decode_array[r.numpy()] - result_unit = tf.concat([prev_array, [c.numpy()]], -1) - result_array.append(result_unit.numpy()) - # result_array.append(tf.concat([decode_array[idx], result[:,idx_idx]], -1)) - decode_array = tf.constant(result_array) - del look_ahead_mask - decode_array = decode_array[0] - - else: - for i in range(min(self.max_seq, length)): - # print(decode_array.shape[1]) - if decode_array.shape[1] >= self.max_seq: - break - if i % 100 == 0: - print('generating... {}% completed'.format((i/min(self.max_seq, length))*100)) - _, _, look_ahead_mask = \ - utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array) - - result = self.call(decode_array, lookup_mask=look_ahead_mask, training=False) - if tf_board: - tf.summary.image('generate_vector', tf.expand_dims(result, -1), i) - # import sys - # tf.print('[debug out:]', result, sys.stdout ) - u = random.uniform(0, 1) - if u > 1: - result = tf.argmax(result[:, -1], -1) - result = tf.cast(result, tf.int32) - decode_array = tf.concat([decode_array, tf.expand_dims(result, -1)], -1) - else: - pdf = tfp.distributions.Categorical(probs=result[:, -1]) - result = pdf.sample(1) - result = tf.transpose(result, (1, 0)) - result = tf.cast(result, tf.int32) - decode_array = tf.concat([decode_array, result], -1) - # decode_array = tf.concat([decode_array, tf.expand_dims(result[:, -1], 0)], -1) - del look_ahead_mask - decode_array = decode_array[0] - - return decode_array.numpy() - - def _set_metrics(self): - accuracy = keras.metrics.SparseCategoricalAccuracy() - self.custom_metrics = [accuracy] - - def __load_config(self, config): - self._debug = config['debug'] - self.max_seq = config['max_seq'] - self.num_layer = config['num_layer'] - self.embedding_dim = config['embedding_dim'] - self.vocab_size = config['vocab_size'] - self.dist = config['dist'] - - def reset_metrics(self): - for metric in self.custom_metrics: - metric.reset_states() - return - - @staticmethod - def __prepare_train_data(x, y): - # start_token = tf.ones((y.shape[0], 1), dtype=y.dtype) * par.token_sos - # end_token = tf.ones((y.shape[0], 1), dtype=y.dtype) * par.token_eos - - # # method with eos - # out_tar = tf.concat([y[:, :-1], end_token], -1) - # inp_tar = tf.concat([start_token, y[:, :-1]], -1) - # x = tf.concat([start_token, x[:, 2:], end_token], -1) - - # method without eos - # x = data.add_noise(x, rate=0.01) - return x, y +# class MusicTransformerDecoder(torch.nn.Module): +# def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6, +# max_seq=2048, dropout=0.2, debug=False, loader_path=None, dist=False): +# super(MusicTransformerDecoder, self).__init__() +# +# if loader_path is not None: +# self.load_config_file(loader_path) +# else: +# self._debug = debug +# self.max_seq = max_seq +# self.num_layer = num_layer +# self.embedding_dim = embedding_dim +# self.vocab_size = vocab_size +# self.dist = dist +# +# self.Decoder = Encoder( +# num_layers=self.num_layer, d_model=self.embedding_dim, +# input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq) +# self.fc = keras.layers.Dense(self.vocab_size, activation=None, name='output') +# +# self._set_metrics() +# +# if loader_path is not None: +# self.load_ckpt_file(loader_path) +# +# def call(self, inputs, training=None, eval=None, lookup_mask=None): +# decoder, w = self.Decoder(inputs, training=training, mask=lookup_mask) +# fc = self.fc(decoder) +# if training: +# return fc +# elif eval: +# return fc, w +# else: +# return tf.nn.softmax(fc) +# +# def train_on_batch(self, x, y=None, sample_weight=None, class_weight=None, reset_metrics=True): +# if self._debug: +# tf.print('sanity:\n', self.sanity_check(x, y, mode='d'), output_stream=sys.stdout) +# +# x, y = self.__prepare_train_data(x, y) +# +# _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) +# +# if self.dist: +# predictions = self.__dist_train_step( +# x, y, look_ahead_mask, True) +# else: +# predictions = self.__train_step(x, y, look_ahead_mask, True) +# +# if self._debug: +# print('train step finished') +# result_metric = [] +# +# if self.dist: +# loss = self._distribution_strategy.reduce(tf.distribute.ReduceOp.MEAN, self.loss_value, None) +# else: +# loss = tf.reduce_mean(self.loss_value) +# loss = tf.reduce_mean(loss) +# for metric in self.custom_metrics: +# result_metric.append(metric(y, predictions).numpy()) +# +# return [loss.numpy()]+result_metric +# +# def __dist_train_step(self, inp_tar, out_tar, lookup_mask, training): +# return self._distribution_strategy.experimental_run_v2( +# self.__train_step, args=(inp_tar, out_tar, lookup_mask, training)) +# +# def __train_step(self, inp_tar, out_tar, lookup_mask, training): +# with tf.GradientTape() as tape: +# predictions = self.call( +# inputs=inp_tar, lookup_mask=lookup_mask, training=training +# ) +# self.loss_value = self.loss(out_tar, predictions) +# gradients = tape.gradient(self.loss_value, self.trainable_variables) +# self.grad = gradients +# self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) +# +# return predictions +# +# def evaluate(self, x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None, callbacks=None, +# max_queue_size=10, workers=1, use_multiprocessing=False): +# +# # x, inp_tar, out_tar = MusicTransformer.__prepare_train_data(x, y) +# _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) +# predictions, w = self.call( +# x, lookup_mask=look_ahead_mask, training=False, eval=True) +# loss = tf.reduce_mean(self.loss(y, predictions)) +# result_metric = [] +# for metric in self.custom_metrics: +# result_metric.append(metric(y, tf.nn.softmax(predictions)).numpy()) +# return [loss.numpy()] + result_metric, w +# +# def save(self, filepath, overwrite=True, include_optimizer=False, save_format=None): +# config_path = filepath+'/'+'config.json' +# ckpt_path = filepath+'/ckpt' +# +# self.save_weights(ckpt_path, save_format='tf') +# with open(config_path, 'w') as f: +# json.dump(self.get_config(), f) +# return +# +# def load_config_file(self, filepath): +# config_path = filepath + '/' + 'config.json' +# with open(config_path, 'r') as f: +# config = json.load(f) +# self.__load_config(config) +# +# def load_ckpt_file(self, filepath, ckpt_name='ckpt'): +# ckpt_path = filepath + '/' + ckpt_name +# try: +# self.load_weights(ckpt_path) +# except FileNotFoundError: +# print("[Warning] model will be initialized...") +# +# def sanity_check(self, x, y, mode='v', step=None): +# # mode: v -> vector, d -> dict +# # x, inp_tar, out_tar = self.__prepare_train_data(x, y) +# +# _, tar_mask, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x) +# predictions = self.call( +# x, lookup_mask=look_ahead_mask, training=False) +# +# if mode == 'v': +# tf.summary.image('vector', tf.expand_dims(predictions, -1), step) +# return predictions +# elif mode == 'd': +# dic = {} +# for row in tf.argmax(predictions, -1).numpy(): +# for col in row: +# try: +# dic[str(col)] += 1 +# except KeyError: +# dic[str(col)] = 1 +# return dic +# else: +# tf.summary.image('tokens', tf.argmax(predictions, -1), step) +# return tf.argmax(predictions, -1) +# +# def get_config(self): +# config = {} +# config['debug'] = self._debug +# config['max_seq'] = self.max_seq +# config['num_layer'] = self.num_layer +# config['embedding_dim'] = self.embedding_dim +# config['vocab_size'] = self.vocab_size +# config['dist'] = self.dist +# return config +# +# def generate(self, prior: list, beam=None, length=2048, tf_board=False): +# decode_array = prior +# decode_array = tf.constant([decode_array]) +# +# # TODO: add beam search +# if beam is not None: +# k = beam +# for i in range(min(self.max_seq, length)): +# if decode_array.shape[1] >= self.max_seq: +# break +# if i % 100 == 0: +# print('generating... {}% completed'.format((i/min(self.max_seq, length))*100)) +# _, _, look_ahead_mask = \ +# utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array) +# +# result = self.call(decode_array, lookup_mask=look_ahead_mask, training=False, eval=False) +# if tf_board: +# tf.summary.image('generate_vector', tf.expand_dims([result[0]], -1), i) +# +# result = result[:,-1,:] +# result = tf.reshape(result, (1, -1)) +# result, result_idx = tf.nn.top_k(result, k) +# row = result_idx // par.vocab_size +# col = result_idx % par.vocab_size +# +# result_array = [] +# for r, c in zip(row[0], col[0]): +# prev_array = decode_array[r.numpy()] +# result_unit = tf.concat([prev_array, [c.numpy()]], -1) +# result_array.append(result_unit.numpy()) +# # result_array.append(tf.concat([decode_array[idx], result[:,idx_idx]], -1)) +# decode_array = tf.constant(result_array) +# del look_ahead_mask +# decode_array = decode_array[0] +# +# else: +# for i in range(min(self.max_seq, length)): +# # print(decode_array.shape[1]) +# if decode_array.shape[1] >= self.max_seq: +# break +# if i % 100 == 0: +# print('generating... {}% completed'.format((i/min(self.max_seq, length))*100)) +# _, _, look_ahead_mask = \ +# utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array) +# +# result = self.call(decode_array, lookup_mask=look_ahead_mask, training=False) +# if tf_board: +# tf.summary.image('generate_vector', tf.expand_dims(result, -1), i) +# # import sys +# # tf.print('[debug out:]', result, sys.stdout ) +# u = random.uniform(0, 1) +# if u > 1: +# result = tf.argmax(result[:, -1], -1) +# result = tf.cast(result, tf.int32) +# decode_array = tf.concat([decode_array, tf.expand_dims(result, -1)], -1) +# else: +# pdf = tfp.distributions.Categorical(probs=result[:, -1]) +# result = pdf.sample(1) +# result = tf.transpose(result, (1, 0)) +# result = tf.cast(result, tf.int32) +# decode_array = tf.concat([decode_array, result], -1) +# # decode_array = tf.concat([decode_array, tf.expand_dims(result[:, -1], 0)], -1) +# del look_ahead_mask +# decode_array = decode_array[0] +# +# return decode_array.numpy() +# +# def _set_metrics(self): +# accuracy = keras.metrics.SparseCategoricalAccuracy() +# self.custom_metrics = [accuracy] +# +# def __load_config(self, config): +# self._debug = config['debug'] +# self.max_seq = config['max_seq'] +# self.num_layer = config['num_layer'] +# self.embedding_dim = config['embedding_dim'] +# self.vocab_size = config['vocab_size'] +# self.dist = config['dist'] +# +# def reset_metrics(self): +# for metric in self.custom_metrics: +# metric.reset_states() +# return +# +# @staticmethod +# def __prepare_train_data(x, y): +# # start_token = tf.ones((y.shape[0], 1), dtype=y.dtype) * par.token_sos +# # end_token = tf.ones((y.shape[0], 1), dtype=y.dtype) * par.token_eos +# +# # # method with eos +# # out_tar = tf.concat([y[:, :-1], end_token], -1) +# # inp_tar = tf.concat([start_token, y[:, :-1]], -1) +# # x = tf.concat([start_token, x[:, 2:], end_token], -1) +# +# # method without eos +# # x = data.add_noise(x, rate=0.01) +# return x, y # # if __name__ == '__main__': diff --git a/train.py b/train.py index dea03f0..9d58a22 100644 --- a/train.py +++ b/train.py @@ -60,7 +60,7 @@ ) criterion = TransformerLoss opt = optim.Adam(mt.parameters(), lr=l_r) -metric_set = MetricsSet([Accuracy, ]) +metric_set = MetricsSet({'accuracy': Accuracy(), 'loss': TransformerLoss()}) # define tensorboard writer current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') @@ -70,12 +70,11 @@ train_summary_writer = SummaryWriter(train_log_dir) eval_summary_writer = SummaryWriter(eval_log_dir) - # Train Start idx = 0 -opt.zero_grad() for e in range(epochs): for b in range(len(dataset.files) // batch_size): + opt.zero_grad() try: batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq) batch_x = torch.from_numpy(batch_x) @@ -83,31 +82,33 @@ except: continue - sample = mt.train_forward(batch_x) - loss = criterion(sample, batch_y) + sample = mt.teacher_forcing_forward(batch_x) + metrics = metric_set(sample, batch_y) + loss = metrics['loss'] loss.backward() opt.step() - - result_metrics = metric_set(sample, batch_y) + + # result_metrics = metric_set(sample, batch_y) if b % 100 == 0: eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval') - eval_result_metrics, weights = mt.evaluate(eval_x, eval_y) - mt.save(save_path) + eval_preiction, weights = mt.teacher_forcing_forward(eval_x) + eval_metrics = metric_set(eval_preiction, eval_y) + torch.save(mt, save_path) if b == 0: train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e) train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e) - train_summary_writer.add_scalar('loss', result_metrics[0], global_step=idx) - train_summary_writer.add_scalar('accuracy', result_metrics[1], global_step=idx) + train_summary_writer.add_scalar('loss', metrics['loss'], global_step=idx) + train_summary_writer.add_scalar('accuracy', metrics['accuracy'], global_step=idx) - eval_summary_writer.add_scalar('loss', eval_result_metrics[0], global_step=idx) - eval_summary_writer.add_scalar('accuracy', eval_result_metrics[1], global_step=idx) + eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx) + eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx) for i, weight in enumerate(weights): attn_log_name = "attn/layer-{}".format(i) - utils.attention_image_summary(attn_log_name, step=idx) + utils.attention_image_summary(attn_log_name, weights, step=idx, writer=eval_summary_writer) idx += 1 print('\n====================================================') print('Epoch/Batch: {}/{}'.format(e, b)) - print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1])) - print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1])) + print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy'])) + print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['acccuracy'])) diff --git a/utils.py b/utils.py index 1693b7d..85999c0 100644 --- a/utils.py +++ b/utils.py @@ -102,7 +102,7 @@ def get_masked_with_pad_tensor(size, src, trg): trg_mask = torch.equal(src, src_pad_tensor) if trg is not None: trg_pad_tensor = torch.ones_like(trg) * par.pad_token - dec_trg_mask = torch.equal(trg, trg_pad_tensor) + dec_trg_mask = trg == trg_pad_tensor # boolean reversing i.e) True * -1 + 1 = False seq_mask = sequence_mask(torch.arange(1, size+1), size) * -1 + 1 look_ahead_mask = torch.max(dec_trg_mask, seq_mask) @@ -161,7 +161,7 @@ def shape_list(x): return res -def attention_image_summary(attn, step=0, writer=None): +def attention_image_summary(name, attn, step=0, writer=None): """Compute color image summary. Args: attn: a Tensor with shape [batch, num_heads, query_length, memory_length] @@ -183,7 +183,7 @@ def attention_image_summary(attn, step=0, writer=None): image = F.pad(image, [0, 0, 0, 0, 0, 0, 0, torch.fmod(-num_heads, 3)]) image = split_last_dimension(image, 3) image = torch.max(image, dim=4) - writer.add_image(attn, image, max_outputs=1, global_step=step) + writer.add_image(name, image, global_step=step, deformats='HWC') def split_last_dimension(x, n):