Commit f56b158

implement module to torch

jason9693 committed Sep 3, 2019
1 parent 2e8f307 commit f56b158
Showing 3 changed files with 138 additions and 64 deletions.
75 changes: 48 additions & 27 deletions custom/layers.py
@@ -70,43 +70,43 @@ def forward(self, inputs, mask=None, **kwargs):
"""
q = inputs[0]
q = self.Wq(q)
q = torch.Tensor.reshape(q, (q.size(0), q.size(1), self.h, -1))
q = torch.Tensor.transpose(q, (0, 2, 1, 3)) # batch, h, seq, dh
q = torch.reshape(q, (q.size(0), q.size(1), self.h, -1))
q = torch.transpose(q, (0, 2, 1, 3)) # batch, h, seq, dh

k = inputs[1]
k = self.Wk(k)
k = torch.Tensor.reshape(k, (k.size(0), k.size(1), self.h, -1))
k = torch.Tensor.transpose(k, (0, 2, 1, 3))
k = torch.reshape(k, (k.size(0), k.size(1), self.h, -1))
k = torch.transpose(k, (0, 2, 1, 3))

v = inputs[2]
v = self.Wv(v)
v = torch.Tensor.reshape(v, (v.size(0), v.size(1), self.h, -1))
v = torch.Tensor.transpose(v, (0, 2, 1, 3))
v = torch.reshape(v, (v.size(0), v.size(1), self.h, -1))
v = torch.transpose(v, (0, 2, 1, 3))

self.len_k = k.shape[2]
self.len_q = q.shape[2]

E = self._get_left_embedding(self.len_q, self.len_k)
QE = torch.Tensor.einsum('bhld,md->bhlm', q, E)
QE = torch.einsum('bhld,md->bhlm', q, E)
QE = self._qe_masking(QE)
# print(QE.shape)
Srel = self._skewing(QE)

Kt = torch.Tensor.transpose(k,[0, 1, 3, 2])
QKt = torch.Tensor.matmul(q, Kt)
Kt = torch.transpose(k,[0, 1, 3, 2])
QKt = torch.matmul(q, Kt)
logits = QKt + Srel
logits = logits / math.sqrt(self.dh)

if mask is not None:
logits += (torch.Tensor.cast(mask, torch.Tensor.float) * -1e9)
logits += (torch.Tensor.cast(mask, torch.float) * -1e9)

attention_weights = F.softmax(logits, -1)
# tf.print('logit result: \n', logits, output_stream=sys.stdout)
attention = torch.Tensor.matmul(attention_weights, v)
attention = torch.matmul(attention_weights, v)
# tf.print('attention result: \n', attention, output_stream=sys.stdout)

out = torch.Tensor.transpose(attention, (0, 2, 1, 3))
out = torch.Tensor.reshape(out, (out.size(0), -1, self.d))
out = torch.transpose(attention, (0, 2, 1, 3))
out = torch.reshape(out, (out.size(0), -1, self.d))

out = self.fc(out)
return out, attention_weights
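Note: several of the `+` lines above keep TensorFlow-style calls that PyTorch does not have. `torch.transpose` takes exactly two dimension indices, so passing a permutation tuple will not work (`Tensor.permute` is the equivalent), and `torch.Tensor.cast` is not a PyTorch method (`Tensor.float()` / `Tensor.type()` cover it). A minimal sketch of the head split/merge as it would usually be written in idiomatic PyTorch (helper names are illustrative, not from this repo):

import torch

def split_heads(x: torch.Tensor, h: int) -> torch.Tensor:
    # (batch, seq, d_model) -> (batch, h, seq, d_model // h)
    b, s, d = x.size()
    return x.reshape(b, s, h, d // h).permute(0, 2, 1, 3)

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    # (batch, h, seq, dh) -> (batch, seq, h * dh)
    b, h, s, dh = x.size()
    return x.permute(0, 2, 1, 3).reshape(b, s, h * dh)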
@@ -116,24 +116,14 @@ def _get_left_embedding(self, len_q, len_k):
         e = self.E[starting_point:, :]
         return e

-    # @staticmethod
-    # def _qe_masking(qe):
-    #     mask = tf.sequence_mask(
-    #         tf.range(qe.shape[-1] - 1, qe.shape[-1] - qe.shape[-2] - 1, -1), qe.shape[-1])
-    #
-    #     mask = tf.logical_not(mask)
-    #     mask = tf.cast(mask, tf.float32)
-    #
-    #     return mask * qe
-
     def _skewing(self, tensor: torch.Tensor):
-        padded = torch.Tensor.pad(tensor, [[0, 0], [0, 0], [0, 0], [1, 0]])
-        reshaped = torch.Tensor.reshape(padded, shape=[-1, padded.size(1), padded.size(-1), padded.size(-2)])
+        padded = torch.pad(tensor, [[0, 0], [0, 0], [0, 0], [1, 0]])
+        reshaped = torch.reshape(padded, shape=[-1, padded.size(1), padded.size(-1), padded.size(-2)])
         Srel = reshaped[:, :, 1:, :]
         # print('Sre: {}'.format(Srel))

         if self.len_k > self.len_q:
-            Srel = torch.Tensor.pad(Srel, [[0, 0], [0, 0], [0, 0], [0, self.len_k - self.len_q]])
+            Srel = torch.pad(Srel, [[0, 0], [0, 0], [0, 0], [0, self.len_k - self.len_q]])
         elif self.len_k < self.len_q:
             Srel = Srel[:, :, :, :self.len_k]
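Note: `torch.pad` does not exist either; padding lives in `torch.nn.functional.pad`, which takes a flat tuple ordered from the last dimension backwards rather than TF-style per-dimension pairs. A hedged sketch of the same skewing step (my reading of the intent; the `len_k`/`len_q` handling mirrors the code above):

import torch
import torch.nn.functional as F

def skew(qe: torch.Tensor, len_q: int, len_k: int) -> torch.Tensor:
    # qe: (batch, h, len_q, len_q) relative-position logits
    padded = F.pad(qe, (1, 0))                        # one zero column on the left of the last dim
    b, h, l, m = padded.size()
    srel = padded.reshape(b, h, m, l)[:, :, 1:, :]    # the reshape shifts each row into place
    if len_k > len_q:
        srel = F.pad(srel, (0, len_k - len_q))
    elif len_k < len_q:
        srel = srel[:, :, :, :len_k]
    return srel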

@@ -208,4 +198,35 @@ def forward(self, x, encode_out, mask=None, lookup_mask=None, w_out=False, **kwargs):
         if w_out:
             return out, aw1, aw2
         else:
-            return out
+            return out
+
+
+class Encoder(torch.nn.Module):
+    def __init__(self, num_layers, d_model, input_vocab_size, rate=0.1, max_len=None):
+        super(Encoder, self).__init__()
+
+        self.d_model = d_model
+        self.num_layers = num_layers
+
+        self.embedding = torch.nn.Embedding(num_embeddings=input_vocab_size, embedding_dim=d_model)
+        # self.embedding = keras.layers.Embedding(input_vocab_size, d_model)
+        # if max_len is not None:
+        #     self.pos_encoding = PositionEmbedding(max_seq=max_len, embedding_dim=self.d_model)
+        if True:
+            self.pos_encoding = DynamicPositionEmbedding(self.d_model, max_seq=max_len)
+
+        self.enc_layers = [EncoderLayer(d_model, rate, h=self.d_model // 64, additional=False, max_seq=max_len)
+                           for i in range(num_layers)]
+        self.dropout = torch.nn.Dropout(rate)
+
+    def call(self, x, mask=None):
+        weights = []
+        # adding embedding and position encoding.
+        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
+        x *= torch.sqrt(self.d_model)
+        x = self.pos_encoding(x)
+        x = self.dropout(x)
+        for i in range(self.num_layers):
+            x, w = self.enc_layers[i](x, mask)
+            weights.append(w)
+        return x, weights  # (batch_size, input_seq_len, d_model)
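Note: three PyTorch pitfalls survive in the `Encoder` added above. `call` is the Keras entry point (PyTorch dispatches through `forward`), a plain Python list of sub-layers is invisible to `parameters()` and `.to(device)` (use `torch.nn.ModuleList`), and `torch.sqrt` expects a tensor, so scaling by the integer `d_model` wants `math.sqrt`. A corrected sketch, assuming the same `EncoderLayer` and `DynamicPositionEmbedding` interfaces used above:

import math
import torch

class Encoder(torch.nn.Module):
    def __init__(self, num_layers, d_model, input_vocab_size, rate=0.1, max_len=None):
        super().__init__()
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(input_vocab_size, d_model)
        self.pos_encoding = DynamicPositionEmbedding(d_model, max_seq=max_len)
        self.enc_layers = torch.nn.ModuleList(
            EncoderLayer(d_model, rate, h=d_model // 64, additional=False, max_seq=max_len)
            for _ in range(num_layers))
        self.dropout = torch.nn.Dropout(rate)

    def forward(self, x, mask=None):
        x = self.embedding(x) * math.sqrt(self.d_model)  # scale embeddings before adding positions
        x = self.dropout(self.pos_encoding(x))
        weights = []
        for layer in self.enc_layers:
            x, w = layer(x, mask)
            weights.append(w)
        return x, weights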
51 changes: 48 additions & 3 deletions model.py
@@ -5,7 +5,7 @@
 import sys
 import torch
 import json
-import tensorflow_probability as tfp
+# import tensorflow_probability as tfp
 import random
 import utils

@@ -25,10 +25,55 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
         self.vocab_size = vocab_size
         self.dist = dist

+        self.Decoder = Encoder(
+            num_layers=self.num_layer, d_model=self.embedding_dim,
+            input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq)
+        self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)
+
+        self._set_metrics()
+
+    def forward(self, x):
+        decoder, w = self.Decoder(x, mask=lookup_mask)
+        fc = self.fc(decoder)
+        if self.training:
+            return fc
+        elif eval:
+            return fc, w
+        else:
+            return F.softmax(fc)

-    def forward(self, *input):
-        pass
+    def generate(self, prior: list, length=2048, tf_board=False):
+        decode_array = prior
+        decode_array = tf.constant([decode_array])
+        for i in range(min(self.max_seq, length)):
+            # print(decode_array.shape[1])
+            if decode_array.shape[1] >= self.max_seq:
+                break
+            if i % 100 == 0:
+                print('generating... {}% completed'.format((i / min(self.max_seq, length)) * 100))
+            _, _, look_ahead_mask = \
+                utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array)
+
+            result = self.call(decode_array, lookup_mask=look_ahead_mask, training=False)
+            if tf_board:
+                tf.summary.image('generate_vector', tf.expand_dims(result, -1), i)
+            # import sys
+            # tf.print('[debug out:]', result, sys.stdout)
+            u = random.uniform(0, 1)
+            if u > 1:
+                result = tf.argmax(result[:, -1], -1)
+                result = tf.cast(result, tf.int32)
+                decode_array = tf.concat([decode_array, tf.expand_dims(result, -1)], -1)
+            else:
+                pdf = torch.distributions.OneHotCategorical(probs=result[:, -1])
+                result = pdf.sample(1)
+                result = torch.transpose(result, (1, 0))
+                result = tf.cast(result, tf.int32)
+                decode_array = tf.concat([decode_array, result], -1)
+                # decode_array = tf.concat([decode_array, tf.expand_dims(result[:, -1], 0)], -1)
+            del look_ahead_mask
+        decode_array = decode_array[0]
+        return decode_array


class MusicTransformerDecoder(torch.nn.Module):
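Note: the added `generate` still drives its loop with TensorFlow ops (`tf.constant`, `tf.cast`, `tf.concat`) inside a torch module, and since `u = random.uniform(0, 1)` is never greater than 1, only the categorical branch is reachable. A hedged pure-PyTorch sketch of that sampling step (helper name and shapes are my assumptions, not this repo's API):

import torch

def sample_next_token(probs_last: torch.Tensor, decode_array: torch.Tensor) -> torch.Tensor:
    # probs_last: (batch, vocab) distribution over the next event;
    # decode_array: (batch, seq) int64 tokens generated so far
    pdf = torch.distributions.Categorical(probs=probs_last)
    next_token = pdf.sample().unsqueeze(-1)  # (batch, 1)
    return torch.cat([decode_array, next_token], dim=-1)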
76 changes: 42 additions & 34 deletions utils.py
@@ -1,7 +1,7 @@
 import os
 import numpy as np
 from deprecated.sequence import EventSeq, ControlSeq
-import torch as torch
+import torch
 import params as par


@@ -94,17 +94,17 @@ def get_masked_with_pad_tensor(size, src, trg):
     :param trg: target tensor
     :return:
     """
-    src = tf.cast(src[:, tf.newaxis, tf.newaxis, :], tf.int32)
-    trg = tf.cast(trg[:, tf.newaxis, tf.newaxis, :], tf.int32)
-    src_pad_tensor = tf.ones_like(src) * par.pad_token
-    src_mask = tf.cast(tf.equal(src, src_pad_tensor), dtype=tf.int32)
-    trg_mask = tf.cast(tf.equal(src, src_pad_tensor), dtype=tf.int32)
+    src = src[:, None, None, :]
+    trg = trg[:, None, None, :]
+    src_pad_tensor = torch.ones_like(src) * par.pad_token
+    src_mask = torch.equal(src, src_pad_tensor)
+    trg_mask = torch.equal(src, src_pad_tensor)
     if trg is not None:
-        trg_pad_tensor = tf.ones_like(trg) * par.pad_token
-        dec_trg_mask = tf.cast(tf.equal(trg, trg_pad_tensor), dtype=tf.int32)
+        trg_pad_tensor = torch.ones_like(trg) * par.pad_token
+        dec_trg_mask = torch.equal(trg, trg_pad_tensor)
         # boolean reversing i.e) True * -1 + 1 = False
-        seq_mask = tf.sequence_mask(list(range(1, size+1)), size, dtype=tf.int32) * -1 + 1
-        look_ahead_mask = tf.cast(tf.maximum(dec_trg_mask, seq_mask), dtype=tf.int32)
+        seq_mask = sequence_mask(torch.arange(1, size+1), size) * -1 + 1
+        look_ahead_mask = torch.max(dec_trg_mask, seq_mask)
     else:
         trg_mask = None
         look_ahead_mask = None
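Note: unlike elementwise `tf.equal`, `torch.equal(a, b)` collapses to a single Python bool, so the masks in the `+` lines above would not be tensors; the elementwise form is `torch.eq` (or the `==` operator). A hedged sketch of the padding mask under that reading, keeping the `(batch, 1, 1, seq)` broadcast shape and the `par.pad_token` constant from above:

import torch

def pad_mask(seq: torch.Tensor, pad_token: int) -> torch.Tensor:
    # (batch, seq_len) -> (batch, 1, 1, seq_len); 1 marks a padding position
    return (seq == pad_token).long()[:, None, None, :]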
@@ -118,7 +118,7 @@ def get_mask_tensor(size):
     :return:
     """
     # boolean reversing i.e) True * -1 + 1 = False
-    seq_mask = tf.sequence_mask(range(1, size + 1), size, dtype=tf.int32) * -1 + 1
+    seq_mask = sequence_mask(torch.arange(1, size + 1), size) * -1 + 1
     return seq_mask


@@ -139,11 +139,11 @@ def pad_with_length(max_length: int, seq: list, pad_val: float=par.pad_token):
     return seq + pad


-def append_token(data: tf.Tensor):
-    start_token = tf.ones((data.shape[0], 1), dtype=data.dtype) * par.token_sos
-    end_token = tf.ones((data.shape[0], 1), dtype=data.dtype) * par.token_eos
+def append_token(data: torch.Tensor):
+    start_token = torch.ones((data.size(0), 1), dtype=data.dtype) * par.token_sos
+    end_token = torch.ones((data.size(0), 1), dtype=data.dtype) * par.token_eos

-    return tf.concat([start_token, data, end_token], -1)
+    return torch.cat([start_token, data, end_token], -1)


 def weights2boards(weights, dir, step):  # weights stored weight[layer][w1,w2]
@@ -154,17 +154,17 @@ def weights2boards(weights, dir, step):  # weights stored weight[layer][w1,w2]


 def shape_list(x):
-    """Shape list"""
-    x_shape = tf.shape(x)
-    x_get_shape = x.get_shape().as_list()
-
-    res = []
-    for i, d in enumerate(x_get_shape):
-        if d is not None:
-            res.append(d)
-        else:
-            res.append(x_shape[i])
-    return res
+    """Shape list"""
+    x_shape = x.size()
+    x_get_shape = list(x.size())
+
+    res = []
+    for i, d in enumerate(x_get_shape):
+        if d is not None:
+            res.append(d)
+        else:
+            res.append(x_shape[i])
+    return res


 def attention_image_summary(attn, step=0):
@@ -182,13 +182,13 @@
     """
     num_heads = shape_list(attn)[1]
     # [batch, query_length, memory_length, num_heads]
-    image = tf.transpose(attn, [0, 2, 3, 1])
-    image = tf.math.pow(image, 0.2)  # for high-dynamic-range
+    image = attn.view([0, 2, 3, 1])
+    image = torch.pow(image, 0.2)  # for high-dynamic-range
     # Each head will correspond to one of RGB.
     # pad the heads to be a multiple of 3
     image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, tf.math.mod(-num_heads, 3)]])
     image = split_last_dimension(image, 3)
-    image = tf.reduce_max(image, 4)
+    image = torch.max(image, dim=4)
     tf.summary.image("attention", image, max_outputs=1, step=step)
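Note: `Tensor.view` treats its argument as target sizes, not an axis order, so `attn.view([0, 2, 3, 1])` is not a port of `tf.transpose` and will error on non-empty tensors; `permute` is the equivalent. `torch.max(image, dim=4)` also returns a `(values, indices)` pair rather than a tensor, so the reduction wants `image.max(dim=4).values`. A hedged sketch of the transpose-and-compress step only:

import torch

def attention_to_image(attn: torch.Tensor) -> torch.Tensor:
    # attn: (batch, heads, q_len, k_len) attention weights
    image = attn.permute(0, 2, 3, 1)  # -> (batch, q_len, k_len, heads)
    return torch.pow(image, 0.2)      # compress for high dynamic range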


@@ -205,22 +205,30 @@ def split_last_dimension(x, n):
     m = x_shape[-1]
     if isinstance(m, int) and isinstance(n, int):
         assert m % n == 0
-    return tf.reshape(x, x_shape[:-1] + [n, m // n])
+    return torch.reshape(x, x_shape[:-1] + [n, m // n])


 def subsequent_mask(size):
     "Mask out subsequent positions."
     attn_shape = (1, size, size)
     subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
-    return torch.Tensor.from_numpy(subsequent_mask) == 0
+    return torch.from_numpy(subsequent_mask) == 0


+def sequence_mask(length, max_length=None):
+    """Implements TensorFlow's sequence_mask."""
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
 if __name__ == '__main__':

-    s = np.array([np.array([1,2]*50), np.array([1,2,3,4]*25)])
+    s = np.array([np.array([1, 2]*50), np.array([1, 2, 3, 4]*25)])

-    t = np.array([np.array([2,3,4,5,6]*20), np.array([1,2,3,4,5]*20)])
+    t = np.array([np.array([2, 3, 4, 5, 6]*20), np.array([1, 2, 3, 4, 5]*20)])
     print(t.shape)

-    print(get_masked_with_pad_tensor(100,s,t))
+    print(get_masked_with_pad_tensor(100, s, t))
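The new `sequence_mask` helper mirrors `tf.sequence_mask`: row i is True for the first `length[i]` positions. A quick usage sketch:

import torch

lengths = torch.tensor([1, 3])
print(sequence_mask(lengths, max_length=4))
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False]])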
