From 7f334ea4388e0d2a5e6dce5deea50605dc6d4f2f Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Wed, 4 Sep 2019 05:50:53 +0900 Subject: [PATCH] =?UTF-8?q?torch=20=EB=AA=A8=EB=93=88=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- custom/layers.py | 33 ++++++++++++++------------------- model.py | 3 ++- preprocess.py | 15 --------------- 3 files changed, 16 insertions(+), 35 deletions(-) diff --git a/custom/layers.py b/custom/layers.py index c582850..eda3bd0 100644 --- a/custom/layers.py +++ b/custom/layers.py @@ -35,7 +35,7 @@ def __init__(self, embedding_dim, max_seq=2048): self.positional_embedding = embed_sinusoid_list def forward(self, x): - x = x + Variable(self.positional_embedding[:, :x.size(1), :], requires_grad=False) + x = x + self.positional_embedding[:, :x.size(1), :] return x @@ -71,41 +71,39 @@ def forward(self, inputs, mask=None, **kwargs): q = inputs[0] q = self.Wq(q) q = torch.reshape(q, (q.size(0), q.size(1), self.h, -1)) - q = torch.transpose(q, (0, 2, 1, 3)) # batch, h, seq, dh + q = q.permute(0, 2, 1, 3) # batch, h, seq, dh k = inputs[1] k = self.Wk(k) k = torch.reshape(k, (k.size(0), k.size(1), self.h, -1)) - k = torch.transpose(k, (0, 2, 1, 3)) + k = k.permute(0, 2, 1, 3) v = inputs[2] v = self.Wv(v) v = torch.reshape(v, (v.size(0), v.size(1), self.h, -1)) - v = torch.transpose(v, (0, 2, 1, 3)) + v = v.permute(0, 2, 1, 3) - self.len_k = k.shape[2] - self.len_q = q.shape[2] + self.len_k = k.size(2) + self.len_q = q.size(2) E = self._get_left_embedding(self.len_q, self.len_k) - QE = torch.einsum('bhld,md->bhlm', q, E) + QE = torch.einsum('bhld,md->bhlm', [q, E]) QE = self._qe_masking(QE) # print(QE.shape) Srel = self._skewing(QE) - Kt = torch.transpose(k,[0, 1, 3, 2]) + Kt = k.permute(0, 1, 3, 2) QKt = torch.matmul(q, Kt) logits = QKt + Srel logits = logits / math.sqrt(self.dh) if mask is not None: - logits += (torch.Tensor.cast(mask, torch.float) * -1e9) + logits += (mask * -1e9) attention_weights = F.softmax(logits, -1) - # tf.print('logit result: \n', logits, output_stream=sys.stdout) attention = torch.matmul(attention_weights, v) - # tf.print('attention result: \n', attention, output_stream=sys.stdout) - out = torch.transpose(attention, (0, 2, 1, 3)) + out = attention.view(0, 2, 1, 3) out = torch.reshape(out, (out.size(0), -1, self.d)) out = self.fc(out) @@ -117,13 +115,13 @@ def _get_left_embedding(self, len_q, len_k): return e def _skewing(self, tensor: torch.Tensor): - padded = torch.pad(tensor, [[0, 0], [0,0], [0, 0], [1, 0]]) + padded = F.pad(tensor, [0, 0, 0, 0, 0, 0, 1, 0]) reshaped = torch.reshape(padded, shape=[-1, padded.size(1), padded.size(-1), padded.size(-2)]) Srel = reshaped[:, :, 1:, :] # print('Sre: {}'.format(Srel)) if self.len_k > self.len_q: - Srel = torch.pad(Srel, [[0,0], [0,0], [0,0], [0, self.len_k-self.len_q]]) + Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q]) elif self.len_k < self.len_q: Srel = Srel[:,:,:,:self.len_k] @@ -151,7 +149,7 @@ def forward(self, x, mask=None, **kwargs): attn_out = self.dropout1(attn_out) out1 = self.layernorm1(attn_out+x) - ffn_out = torch.nn.ReLU(self.FFN_pre(out1)) + ffn_out = F.relu(self.FFN_pre(out1)) ffn_out = self.FFN_suf(ffn_out) ffn_out = self.dropout2(ffn_out) out2 = self.layernorm2(out1+ffn_out) @@ -190,7 +188,7 @@ def forward(self, x, encode_out, mask=None, lookup_mask=None, w_out=False, **kwa attn_out2 = self.dropout2(attn_out2) attn_out2 = self.layernorm2(out1+attn_out2) - ffn_out = torch.nn.ReLU(self.FFN_pre(attn_out2)) + ffn_out = F.relu(self.FFN_pre(attn_out2)) ffn_out = self.FFN_suf(ffn_out) ffn_out = self.dropout3(ffn_out) out = self.layernorm3(attn_out2+ffn_out) @@ -209,9 +207,6 @@ def __init__(self, num_layers, d_model, input_vocab_size, rate=0.1, max_len=None self.num_layers = num_layers self.embedding = torch.nn.Embedding(num_embeddings=input_vocab_size, embedding_dim=d_model) - # self.embedding = keras.layers.Embedding(input_vocab_size, d_model) - # if max_len is not None: - # self.pos_encoding = PositionEmbedding(max_seq=max_len, embedding_dim=self.d_model) if True: self.pos_encoding = DynamicPositionEmbedding(self.d_model, max_seq=max_len) diff --git a/model.py b/model.py index 5e21ed2..138a681 100644 --- a/model.py +++ b/model.py @@ -4,6 +4,7 @@ import params as par import sys import torch +import torch.distributions as dist import json # import tensorflow_probability as tfp import random @@ -65,7 +66,7 @@ def generate(self, prior: list, length=2048, tf_board=False): result = tf.cast(result, tf.int32) decode_array = tf.concat([decode_array, tf.expand_dims(result, -1)], -1) else: - pdf = torch.distributions.OneHotCategorical(probs=result[:, -1]) + pdf = dist.OneHotCategorical(probs=result[:, -1]) result = pdf.sample(1) result = torch.transpose(result, (1, 0)) result = tf.cast(result, tf.int32) diff --git a/preprocess.py b/preprocess.py index 99b089d..fe335af 100644 --- a/preprocess.py +++ b/preprocess.py @@ -42,17 +42,6 @@ def preprocess_midi_files_under(midi_root, save_dir): pickle.dump(data, f) -# def _augumentation(seq): -# range_note = range(0, processor.RANGE_NOTE_ON+processor.RANGE_NOTE_OFF) -# range_time = range( -# processor.START_IDX['time_shift'], -# processor.START_IDX['time_shift']+processor.RANGE_TIME_SHIFT -# ) -# for idx, data in enumerate(seq): -# if data in range_note: -# - - class TFRecordsConverter(object): def __init__(self, midi_path, output_dir, num_shards_train=3, num_shards_test=1): @@ -113,10 +102,6 @@ def __write_to_records(self, output_path, indicies): es_seq = self.es_seq_list[i] ctrl_seq = self.ctrl_seq_list[i] - # example = tf.train.Example(features=tf.train.Features(feature={ - # 'label': TFRecordsConverter._int64_feature(label), - # 'text': TFRecordsConverter._bytes_feature(bytes(x, encoding='utf-8'))})) - if __name__ == '__main__': preprocess_midi_files_under(