Skip to content

Commit

Permalink
torch 모듈 수정
Browse files Browse the repository at this point in the history
  • Loading branch information
jason9693 committed Sep 3, 2019
1 parent f56b158 commit 7f334ea
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 35 deletions.
33 changes: 14 additions & 19 deletions custom/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, embedding_dim, max_seq=2048):
self.positional_embedding = embed_sinusoid_list

def forward(self, x):
x = x + Variable(self.positional_embedding[:, :x.size(1), :], requires_grad=False)
x = x + self.positional_embedding[:, :x.size(1), :]
return x


Expand Down Expand Up @@ -71,41 +71,39 @@ def forward(self, inputs, mask=None, **kwargs):
q = inputs[0]
q = self.Wq(q)
q = torch.reshape(q, (q.size(0), q.size(1), self.h, -1))
q = torch.transpose(q, (0, 2, 1, 3)) # batch, h, seq, dh
q = q.permute(0, 2, 1, 3) # batch, h, seq, dh

k = inputs[1]
k = self.Wk(k)
k = torch.reshape(k, (k.size(0), k.size(1), self.h, -1))
k = torch.transpose(k, (0, 2, 1, 3))
k = k.permute(0, 2, 1, 3)

v = inputs[2]
v = self.Wv(v)
v = torch.reshape(v, (v.size(0), v.size(1), self.h, -1))
v = torch.transpose(v, (0, 2, 1, 3))
v = v.permute(0, 2, 1, 3)

self.len_k = k.shape[2]
self.len_q = q.shape[2]
self.len_k = k.size(2)
self.len_q = q.size(2)

E = self._get_left_embedding(self.len_q, self.len_k)
QE = torch.einsum('bhld,md->bhlm', q, E)
QE = torch.einsum('bhld,md->bhlm', [q, E])
QE = self._qe_masking(QE)
# print(QE.shape)
Srel = self._skewing(QE)

Kt = torch.transpose(k,[0, 1, 3, 2])
Kt = k.permute(0, 1, 3, 2)
QKt = torch.matmul(q, Kt)
logits = QKt + Srel
logits = logits / math.sqrt(self.dh)

if mask is not None:
logits += (torch.Tensor.cast(mask, torch.float) * -1e9)
logits += (mask * -1e9)

attention_weights = F.softmax(logits, -1)
# tf.print('logit result: \n', logits, output_stream=sys.stdout)
attention = torch.matmul(attention_weights, v)
# tf.print('attention result: \n', attention, output_stream=sys.stdout)

out = torch.transpose(attention, (0, 2, 1, 3))
out = attention.view(0, 2, 1, 3)
out = torch.reshape(out, (out.size(0), -1, self.d))

out = self.fc(out)
Expand All @@ -117,13 +115,13 @@ def _get_left_embedding(self, len_q, len_k):
return e

def _skewing(self, tensor: torch.Tensor):
padded = torch.pad(tensor, [[0, 0], [0,0], [0, 0], [1, 0]])
padded = F.pad(tensor, [0, 0, 0, 0, 0, 0, 1, 0])
reshaped = torch.reshape(padded, shape=[-1, padded.size(1), padded.size(-1), padded.size(-2)])
Srel = reshaped[:, :, 1:, :]
# print('Sre: {}'.format(Srel))

if self.len_k > self.len_q:
Srel = torch.pad(Srel, [[0,0], [0,0], [0,0], [0, self.len_k-self.len_q]])
Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q])
elif self.len_k < self.len_q:
Srel = Srel[:,:,:,:self.len_k]

Expand Down Expand Up @@ -151,7 +149,7 @@ def forward(self, x, mask=None, **kwargs):
attn_out = self.dropout1(attn_out)
out1 = self.layernorm1(attn_out+x)

ffn_out = torch.nn.ReLU(self.FFN_pre(out1))
ffn_out = F.relu(self.FFN_pre(out1))
ffn_out = self.FFN_suf(ffn_out)
ffn_out = self.dropout2(ffn_out)
out2 = self.layernorm2(out1+ffn_out)
Expand Down Expand Up @@ -190,7 +188,7 @@ def forward(self, x, encode_out, mask=None, lookup_mask=None, w_out=False, **kwa
attn_out2 = self.dropout2(attn_out2)
attn_out2 = self.layernorm2(out1+attn_out2)

ffn_out = torch.nn.ReLU(self.FFN_pre(attn_out2))
ffn_out = F.relu(self.FFN_pre(attn_out2))
ffn_out = self.FFN_suf(ffn_out)
ffn_out = self.dropout3(ffn_out)
out = self.layernorm3(attn_out2+ffn_out)
Expand All @@ -209,9 +207,6 @@ def __init__(self, num_layers, d_model, input_vocab_size, rate=0.1, max_len=None
self.num_layers = num_layers

self.embedding = torch.nn.Embedding(num_embeddings=input_vocab_size, embedding_dim=d_model)
# self.embedding = keras.layers.Embedding(input_vocab_size, d_model)
# if max_len is not None:
# self.pos_encoding = PositionEmbedding(max_seq=max_len, embedding_dim=self.d_model)
if True:
self.pos_encoding = DynamicPositionEmbedding(self.d_model, max_seq=max_len)

Expand Down
3 changes: 2 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import params as par
import sys
import torch
import torch.distributions as dist
import json
# import tensorflow_probability as tfp
import random
Expand Down Expand Up @@ -65,7 +66,7 @@ def generate(self, prior: list, length=2048, tf_board=False):
result = tf.cast(result, tf.int32)
decode_array = tf.concat([decode_array, tf.expand_dims(result, -1)], -1)
else:
pdf = torch.distributions.OneHotCategorical(probs=result[:, -1])
pdf = dist.OneHotCategorical(probs=result[:, -1])
result = pdf.sample(1)
result = torch.transpose(result, (1, 0))
result = tf.cast(result, tf.int32)
Expand Down
15 changes: 0 additions & 15 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,6 @@ def preprocess_midi_files_under(midi_root, save_dir):
pickle.dump(data, f)


# def _augumentation(seq):
# range_note = range(0, processor.RANGE_NOTE_ON+processor.RANGE_NOTE_OFF)
# range_time = range(
# processor.START_IDX['time_shift'],
# processor.START_IDX['time_shift']+processor.RANGE_TIME_SHIFT
# )
# for idx, data in enumerate(seq):
# if data in range_note:
#


class TFRecordsConverter(object):
def __init__(self, midi_path, output_dir,
num_shards_train=3, num_shards_test=1):
Expand Down Expand Up @@ -113,10 +102,6 @@ def __write_to_records(self, output_path, indicies):
es_seq = self.es_seq_list[i]
ctrl_seq = self.ctrl_seq_list[i]

# example = tf.train.Example(features=tf.train.Features(feature={
# 'label': TFRecordsConverter._int64_feature(label),
# 'text': TFRecordsConverter._bytes_feature(bytes(x, encoding='utf-8'))}))


if __name__ == '__main__':
preprocess_midi_files_under(
Expand Down

0 comments on commit 7f334ea

Please sign in to comment.