Commit
Data parallel debugging: first round complete
jason9693 committed Oct 22, 2019
1 parent a6513c5 commit 429775c
Showing 7 changed files with 99 additions and 55 deletions.
8 changes: 5 additions & 3 deletions custom/criterion.py
@@ -1,6 +1,7 @@
from typing import Optional, Any

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import CrossEntropyLoss, _Loss
# from tensorflow.python.keras.optimizer_v2.learning_rate_schedule import LearningRateSchedule

@@ -29,13 +30,14 @@ class SmoothCrossEntropyLoss(_Loss):
"""
__constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']

def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean'):
def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
assert 0.0 <= label_smoothing <= 1.0
super().__init__(reduction=reduction)

self.label_smoothing = label_smoothing
self.vocab_size = vocab_size
self.ignore_index = ignore_index
self.input_is_logits = is_logits

def forward(self, input, target):
"""
@@ -45,9 +47,9 @@ def forward(self, input, target):
Returns:
cross entropy: [1]
"""
mask = (target == self.ignore_index).unsqueeze(1)
mask = (target == self.ignore_index).unsqueeze(-1)

q = torch.nn.functional.one_hot(target, self.vocab_size).type(torch.float32)
q = F.one_hot(target, self.vocab_size).type(torch.float32)
u = 1.0 / self.vocab_size
q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
q_prime = q_prime.masked_fill(mask, 0)
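
The smoothed loss above mixes the one-hot target with a uniform 1/vocab_size distribution, weighted by label_smoothing, and zeroes out positions whose target equals ignore_index. A minimal usage sketch, assuming logits of shape [batch, seq_len, vocab_size] and integer targets of shape [batch, seq_len]; the vocabulary size 390 (= 388 + 2) mirrors the default visible in model.py further down, and the pad token value 388 is only a placeholder:

import torch
from custom.criterion import SmoothCrossEntropyLoss

criterion = SmoothCrossEntropyLoss(label_smoothing=0.1, vocab_size=390, ignore_index=388)
logits = torch.randn(2, 16, 390)            # raw scores; is_logits=True is the new default
targets = torch.randint(0, 390, (2, 16))    # token ids; positions equal to 388 are excluded
loss = criterion(logits, targets)           # scalar under reduction='mean'
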
47 changes: 35 additions & 12 deletions custom/layers.py
@@ -1,3 +1,5 @@
import utils

import math as m
import numpy as np
import math
@@ -34,7 +36,7 @@ def __init__(self, embedding_dim, max_seq=2048):
self.positional_embedding = embed_sinusoid_list

def forward(self, x):
x = x + self.positional_embedding[:, :x.size(1), :]
x = x + torch.from_numpy(self.positional_embedding[:, :x.size(1), :]).to(x.device, dtype=x.dtype)
return x


@@ -85,10 +87,9 @@ def forward(self, inputs, mask=None, **kwargs):
self.len_k = k.size(2)
self.len_q = q.size(2)

E = self._get_left_embedding(self.len_q, self.len_k)
E = self._get_left_embedding(self.len_q, self.len_k).to(q.device)
QE = torch.einsum('bhld,md->bhlm', [q, E])
QE = self._qe_masking(QE)
# print(QE.shape)
Srel = self._skewing(QE)

Kt = k.permute(0, 1, 3, 2)
@@ -97,12 +98,12 @@ def forward(self, inputs, mask=None, **kwargs):
logits = logits / math.sqrt(self.dh)

if mask is not None:
logits += (mask * -1e9)
logits += (mask * -1e9).to(logits.dtype)

attention_weights = F.softmax(logits, -1)
attention = torch.matmul(attention_weights, v)

out = attention.view(0, 2, 1, 3)
out = attention.permute(0, 2, 1, 3)
out = torch.reshape(out, (out.size(0), -1, self.d))

out = self.fc(out)
@@ -114,17 +115,24 @@ def _get_left_embedding(self, len_q, len_k):
return e

def _skewing(self, tensor: torch.Tensor):
padded = F.pad(tensor, [0, 0, 0, 0, 0, 0, 1, 0])
reshaped = torch.reshape(padded, shape=[-1, padded.size(1), padded.size(-1), padded.size(-2)])
padded = F.pad(tensor, [1, 0, 0, 0, 0, 0, 0, 0])
reshaped = torch.reshape(padded, shape=[padded.size(0), padded.size(1), padded.size(-1), padded.size(-2)])
Srel = reshaped[:, :, 1:, :]

if self.len_k > self.len_q:
Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q])
elif self.len_k < self.len_q:
Srel = Srel[:, :, :, :self.len_k]

return Srel

@staticmethod
def _qe_masking(qe):
mask = utils.sequence_mask(
torch.arange(qe.size()[-1] - 1, qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device),
qe.size()[-1])
mask = ~mask.to(mask.device)
return mask.to(qe.dtype) * qe


class EncoderLayer(torch.nn.Module):
def __init__(self, d_model, rate=0.1, h=16, additional=False, max_seq=2048):
@@ -208,18 +216,33 @@ def __init__(self, num_layers, d_model, input_vocab_size, rate=0.1, max_len=None
if True:
self.pos_encoding = DynamicPositionEmbedding(self.d_model, max_seq=max_len)

self.enc_layers = [EncoderLayer(d_model, rate, h=self.d_model // 64, additional=False, max_seq=max_len)
for i in range(num_layers)]
self.enc_layers = torch.nn.ModuleList(
[EncoderLayer(d_model, rate, h=self.d_model // 64, additional=False, max_seq=max_len)
for i in range(num_layers)])
self.dropout = torch.nn.Dropout(rate)

def call(self, x, mask=None):
def forward(self, x, mask=None):
weights = []
# adding embedding and position encoding.
x = self.embedding(x) # (batch_size, input_seq_len, d_model)
x *= torch.sqrt(self.d_model)
x *= math.sqrt(self.d_model)
x = self.pos_encoding(x)
x = self.dropout(x)
for i in range(self.num_layers):
x, w = self.enc_layers[i](x, mask)
weights.append(w)
return x, weights # (batch_size, input_seq_len, d_model)


class MusicTransformerDataParallel(torch.nn.DataParallel):
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)

def forward(self, *inputs, **kwargs):
try:
return super().forward(*inputs)
except NotImplementedError:
return self.module(*inputs)
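
MusicTransformerDataParallel overrides __getattr__ so that attributes nn.DataParallel does not define fall through to the wrapped module, and its forward falls back to a plain module call when the parallel path raises NotImplementedError. A small self-contained sketch of the attribute fallthrough, using a stand-in module rather than the full MusicTransformer (whose constructor arguments are not fully shown in this diff):

import torch
from custom.layers import MusicTransformerDataParallel

class TinyModel(torch.nn.Module):            # stand-in for MusicTransformer
    def __init__(self):
        super().__init__()
        self.max_seq = 128                   # plain attribute, not a parameter or buffer
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

model = TinyModel()
if torch.cuda.device_count() > 1:
    model = MusicTransformerDataParallel(model)   # replicate across visible GPUs

print(model.max_seq)   # works wrapped or not: __getattr__ falls back to self.module

In train.py below, the unwrapped model appears to be kept in single_mt so that torch.save stores the bare module rather than the DataParallel wrapper.
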
5 changes: 4 additions & 1 deletion custom/metrics.py
@@ -30,11 +30,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
return super().forward(categorical_input, target)


class MetricsSet(_Metric):
class MetricsSet(object):
def __init__(self, metric_dict: Dict):
super().__init__()
self.metrics = metric_dict

def __call__(self, input: torch.Tensor, target: torch.Tensor):
return self.forward(input=input, target=target)

def forward(self, input: torch.Tensor, target: torch.Tensor):
# return [metric(input, target) for metric in self.metrics]
return {k: metric(input, target) for k, metric in self.metrics.items()}
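
Since __call__ now delegates to forward, a MetricsSet can be invoked like a function and returns a dict keyed by metric name, which is how train.py below consumes it. A short sketch, assuming CategoricalAccuracy (referenced in train.py) lives in this module and accepts a predicted distribution over the vocabulary together with integer targets; the numeric values are placeholders:

import torch
from custom.metrics import MetricsSet, CategoricalAccuracy
from custom.criterion import SmoothCrossEntropyLoss

metric_set = MetricsSet({
    'accuracy': CategoricalAccuracy(),
    'loss': SmoothCrossEntropyLoss(0.1, 390, 388),   # smoothing, vocab_size, pad token
})
preds = torch.randn(2, 16, 390).softmax(-1)          # model output: [batch, seq_len, vocab]
targets = torch.randint(0, 388, (2, 16))
results = metric_set(preds, targets)                 # {'accuracy': ..., 'loss': ...}
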
2 changes: 1 addition & 1 deletion generate.py
@@ -42,7 +42,7 @@
inputs = np.array([[28]])
inputs = torch.from_numpy([inputs]).to(config.device)

result = mt.generate(inputs, beam=1, length=config.length, tf_board_writer=gen_summary_writer)
result = mt(inputs, config.length, gen_summary_writer)

for i in result:
print(i)
52 changes: 29 additions & 23 deletions model.py
@@ -1,6 +1,7 @@
from custom.layers import *
from custom.criterion import *
from custom.layers import Encoder
from custom.config import config

import sys
import torch
@@ -34,23 +35,29 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq)
self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)

def forward(self, x, lookup_mask=None):
decoder, w = self.Decoder(x, mask=lookup_mask)
fc = self.fc(decoder)
fc = fc.softmax(-1)
return fc, w
def forward(self, x, length=None, writer=None):
if self.training:
_, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token)
decoder, w = self.Decoder(x, look_ahead_mask)
fc = self.fc(decoder)
fc = fc.softmax(-1)
return fc.contiguous(), [weight.contiguous() for weight in w]
else:
return self.generate(self.Decoder, x, length, writer).contiguous()

def generate(self, prior: torch.Tensor, length=2048, tf_board_writer: SummaryWriter = None):
def generate(self, decode_fn, prior: torch.Tensor, length=2048, tf_board_writer: SummaryWriter = None):
decode_array = prior
for i in Bar('generating').iter(range(min(self.max_seq, length))):
if decode_array.shape[1] >= self.max_seq:
break
if i % 100 == 0:
print('generating... {}% completed'.format((i / min(self.max_seq, length)) * 100))
_, _, look_ahead_mask = \
utils.get_masked_with_pad_tensor(decode_array.shape[1], decode_array, decode_array)

result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)
# result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)
result, _ = decode_fn(decode_array, look_ahead_mask)
result = self.fc(result)
result = result.softmax(-1)

if tf_board_writer:
tf_board_writer.add_image("logits", result, global_step=i)

@@ -67,17 +74,16 @@ def generate(self, prior: torch.Tensor, length=2048, tf_board_writer: SummaryWri
decode_array = decode_array[0]
return decode_array

def teacher_forcing_forward(self, x, attn=False):
x, _ = self.__prepare_train_data(x, x)
_, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x)

predictions, w = self.forward(
x, lookup_mask=look_ahead_mask,
)

if self._debug:
print('train step finished')
if attn:
return predictions, w
else:
return predictions
# def teacher_forcing_forward(self, x, attn=False):
# _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token)
#
# predictions, w = self(
# x, lookup_mask=look_ahead_mask,
# )
#
# if self._debug:
# print('train step finished')
# if attn:
# return predictions, w
# else:
# return predictions
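
With teacher_forcing_forward retired, both code paths go through forward and are selected by self.training: training builds a look-ahead mask internally and returns softmaxed predictions plus attention weights, while eval dispatches to generate. The following is a minimal, self-contained illustration of that dispatch pattern only, not the actual MusicTransformer (its constructor arguments and config setup are not fully shown in this diff):

import torch

class DispatchExample(torch.nn.Module):
    # forward switches on self.training, mirroring the pattern above
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x, length=None):
        if self.training:                    # teacher-forced pass over the whole sequence
            return self.proj(x).softmax(-1)
        return self.generate(x, length)      # eval: autoregressive decoding

    def generate(self, prior, length):
        out = prior
        for _ in range(length):
            step = self.proj(out[:, -1:, :]).softmax(-1)
            out = torch.cat([out, step], dim=1)
        return out

m = DispatchExample()
m.train()
probs = m(torch.randn(2, 8, 4))              # [2, 8, 4], softmaxed over the last dim
m.eval()
seq = m(torch.randn(2, 1, 4), length=5)      # grows the sequence to [2, 6, 4]

One consequence of this design is visible in generate.py above: callers invoke mt(inputs, length, writer) directly, so the DataParallel wrapper needs no separate generation entry point.
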
28 changes: 18 additions & 10 deletions train.py
@@ -28,7 +28,7 @@


# load data
dataset = Data('dataset/processed')
dataset = Data(config.pickle_dir)
print(dataset)


@@ -44,9 +44,10 @@
dropout=config.dropout,
debug=config.debug, loader_path=config.load_path
)
mt.to(config.device)
opt = optim.Adam(mt.parameters(), lr=config.l_r)
metric_set = MetricsSet({
'accuracy': Accuracy(),
'accuracy': CategoricalAccuracy(),
'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token)
})

@@ -57,6 +58,8 @@
else:
single_mt = mt

print(mt)

# define tensorboard writer
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
train_log_dir = 'logs/mt_decoder/'+current_time+'/train'
@@ -66,32 +69,37 @@
eval_summary_writer = SummaryWriter(eval_log_dir)

# Train Start
print(">> Train start...")
idx = 0
for e in range(config.epochs):
print(">>> [Epoch was updated]")
for b in range(len(dataset.files) // config.batch_size):
opt.zero_grad()
try:
batch_x, batch_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq)
batch_x = torch.from_numpy(batch_x).to(args.device, non_blocking=True)
batch_y - torch.from_numpy(batch_y).to(args.device, non_blocking=True)
except:
batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True)
batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True)
except IndexError:
continue

sample = mt.teacher_forcing_forward(batch_x)
mt.train()
sample, _ = mt.forward(batch_x)
metrics = metric_set(sample, batch_y)
loss = metrics['loss']
loss.backward()
opt.step()
if config.debug:
print("[Loss]: {}".format(loss))

# result_metrics = metric_set(sample, batch_y)
if b % 100 == 0:
eval_x, eval_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq, 'eval')
eval_x = torch.from_numpy(eval_x).to(args.device, non_blocking=True)
eval_y = torch.from_numpy(eval_y).to(args.device, non_blocking=True)
eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, non_blocking=True)
eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, non_blocking=True)

eval_preiction, weights = mt.teacher_forcing_forward(eval_x)
eval_preiction, weights = mt.forward(eval_x)
eval_metrics = metric_set(eval_preiction, eval_y)
torch.save(single_mt, config.model_dir+'train-{}.pth'.format(idx))
torch.save(single_mt, args.model_dir+'train-{}.pth'.format(idx))
if b == 0:
train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e)
train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e)
12 changes: 7 additions & 5 deletions utils.py
@@ -64,15 +64,17 @@ def get_masked_with_pad_tensor(size, src, trg, pad_token):
"""
src = src[:, None, None, :]
trg = trg[:, None, None, :]
src_pad_tensor = torch.ones_like(src) * pad_token
src_pad_tensor = torch.ones_like(src).to(src.device.type) * pad_token
src_mask = torch.equal(src, src_pad_tensor)
trg_mask = torch.equal(src, src_pad_tensor)
if trg is not None:
trg_pad_tensor = torch.ones_like(trg) * pad_token
trg_pad_tensor = torch.ones_like(trg).to(trg.device.type) * pad_token
dec_trg_mask = trg == trg_pad_tensor
# boolean reversing i.e) True * -1 + 1 = False
seq_mask = sequence_mask(torch.arange(1, size+1), size) * -1 + 1
look_ahead_mask = torch.max(dec_trg_mask, seq_mask)
seq_mask = sequence_mask(torch.arange(1, size+1).to(trg.device), size) * -1 + 1
# look_ahead_mask = torch.max(dec_trg_mask, seq_mask)
look_ahead_mask = dec_trg_mask | seq_mask

else:
trg_mask = None
look_ahead_mask = None
@@ -143,7 +145,7 @@ def attention_image_summary(name, attn, step=0, writer=None):
"""
num_heads = shape_list(attn)[1]
# [batch, query_length, memory_length, num_heads]
image = attn.view([0, 2, 3, 1])
image = attn.permute(0, 2, 3, 1)
image = torch.pow(image, 0.2) # for high-dynamic-range
# Each head will correspond to one of RGB.
# pad the heads to be a multiple of 3
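
The look-ahead mask is now a boolean OR of the target padding mask and the flipped sequence mask. utils.sequence_mask itself is not part of this diff; assuming it follows the usual convention (row i is True for the first lengths[i] positions), the expression sequence_mask(torch.arange(1, size+1), size) * -1 + 1 produces a strictly upper-triangular matrix of 1s marking the future positions a decoder step must not attend to. A small worked sketch for size = 4, with the assumed sequence_mask written out explicitly:

import torch

def sequence_mask(lengths, max_len):
    # assumed behaviour: row i is True for the first lengths[i] positions
    return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

size = 4
seq_mask = sequence_mask(torch.arange(1, size + 1), size).long() * -1 + 1   # flip 0 <-> 1
print(seq_mask)
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])    # 1 marks future (disallowed) positions
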