Commit cd841c4

add config set, compatible for multi-gpu, init commit for apex-ddp

jason9693 committed Oct 26, 2019
1 parent c4f8a9a

Showing 11 changed files with 243 additions and 49 deletions.
4 changes: 3 additions & 1 deletion config/base.yml
@@ -1,4 +1,6 @@
+experiment: 'debug'
 max_seq: 2048
 embedding_dim: 256
 num_layers: 6
-event_dim: 388
+event_dim: 388
+fp16:
8 changes: 8 additions & 0 deletions config/debug_train.yml
@@ -0,0 +1,8 @@
+pickle_dir: 'MusicTransformer/dataset/processed'
+epochs: 100
+batch_size: 2
+load_path:
+dropout: 0.1
+debug: 'true'
+l_r: 0.001
+label_smooth: 0.1
6 changes: 6 additions & 0 deletions config/large.yml
@@ -0,0 +1,6 @@
+experiment: 'mt_large'
+max_seq: 2048
+embedding_dim: 512
+num_layers: 6
+event_dim: 388
+fp16:
6 changes: 3 additions & 3 deletions config/train.yml
@@ -1,8 +1,8 @@
-pickle_dir: '/data/private/MusicTransformer/dataset/processed'
+pickle_dir: 'MusicTransformer/dataset/processed'
 epochs: 100
-batch_size: 3
+batch_size: 8
 load_path:
 dropout: 0.1
 debug: 'true'
 l_r: 0.001
-label_smooth: 0.1
+label_smooth: 0.1
8 changes: 8 additions & 0 deletions config/train_ddp.yml
@@ -0,0 +1,8 @@
+pickle_dir: 'MusicTransformer/dataset/processed'
+epochs: 100
+batch_size: 8
+load_path:
+dropout: 0.1
+debug: 'true'
+l_r: 0.001
+label_smooth: 0.1
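
The YAML files above form the "config set" named in the commit message: base.yml and large.yml carry model hyperparameters, while train.yml, debug_train.yml, and train_ddp.yml carry per-run training settings, composed via config.load(args.model_dir, args.configs, initialize=True) as seen in train_ddp.py below. A minimal sketch of how such layered loading can behave; the merge order and attribute-style access are assumptions, since the real loader lives in custom/config.py:

# Sketch: one plausible behavior for a layered YAML config loader like
# custom/config.py's config.load(). The merge rule (later files override
# earlier ones) is an assumption, not the repo's actual code.
import yaml  # PyYAML

class Config(object):
    def load(self, *yml_paths):
        for path in yml_paths:
            with open(path) as f:
                # empty YAML values such as `fp16:` or `load_path:` parse to None
                for k, v in (yaml.safe_load(f) or {}).items():
                    setattr(self, k, v)

config = Config()
config.load('config/base.yml', 'config/train.yml')
print(config.embedding_dim, config.batch_size)  # 256 8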
2 changes: 1 addition & 1 deletion custom/layers.py
@@ -98,7 +98,7 @@ def forward(self, inputs, mask=None, **kwargs):
         logits = logits / math.sqrt(self.dh)

         if mask is not None:
-            logits += (mask * -1e9).to(logits.dtype)
+            logits += (mask.to(torch.int64) * -1e9).to(logits.dtype)

         attention_weights = F.softmax(logits, -1)
         attention = torch.matmul(attention_weights, v)
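The one-line change above casts the boolean mask to int64 before scaling, avoiding bool-tensor arithmetic in `mask * -1e9`. The surrounding logic is the standard additive attention mask: masked positions receive a large negative logit so softmax assigns them near-zero weight. A self-contained sketch of the pattern; the shapes are illustrative, not the layer's real ones:

# Additive attention masking as in custom/layers.py forward(): positions
# where mask == 1 get logit -1e9, so softmax drives their weight to ~0.
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 4)                    # (batch, query, key), dummy values
mask = torch.triu(torch.ones(4, 4), diagonal=1)  # 1 above the diagonal = hide future

logits += (mask.to(torch.int64) * -1e9).to(logits.dtype)

weights = F.softmax(logits, -1)
print(weights[0])  # upper triangle ~0, each row sums to 1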
28 changes: 21 additions & 7 deletions custom/metrics.py
@@ -24,7 +24,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
         :return:
         """
         bool_acc = input.long() == target.long()
-        return bool_acc.sum() / bool_acc.numel()
+        return bool_acc.sum().to(torch.float) / bool_acc.numel()
+
+
+class MockAccuracy(Accuracy):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        return super().forward(input, target)


 class CategoricalAccuracy(Accuracy):
@@ -45,13 +53,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
 class LogitsBucketting(_Metric):
     def __init__(self, vocab_size):
         super().__init__()
-        self.bucket = np.array([0] * vocab_size)

     def forward(self, input: torch.Tensor, target: torch.Tensor):
-        self.bucket[input.flatten().to(torch.int32)] += 1
-
-    def get_bucket(self):
-        return self.bucket
+        return input.argmax(-1).flatten().to(torch.int32).cpu()


 class MetricsSet(object):
@@ -64,4 +68,14 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor):
     def forward(self, input: torch.Tensor, target: torch.Tensor):
         # return [metric(input, target) for metric in self.metrics]
-        return {k: metric(input, target) for k, metric in self.metrics.items()}
+        return {
+            k: metric(input.to(target.device), target)
+            for k, metric in self.metrics.items()}
+
+
+if __name__ == '__main__':
+    met = MockAccuracy()
+    test_tensor1 = torch.ones((3,2)).contiguous().cuda().to(non_blocking=True, dtype=torch.int)
+    test_tensor2 = torch.ones((3,2)).contiguous().cuda().to(non_blocking=True, dtype=torch.int)
+    test_tensor3 = torch.zeros((3,2))
+    print(met(test_tensor1, test_tensor2))
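
Two behavioral notes on this file: MetricsSet now moves input onto target's device before applying each metric, and LogitsBucketting no longer accumulates a count array (get_bucket() is gone); it returns the argmax token ids, which train.py hands straight to add_histogram. A sketch of that consumer side; the dummy shapes and the vocab size of 390 (event_dim 388 plus two special tokens, per model.py's 388+2) are assumptions for illustration:

# Sketch: consuming the reworked LogitsBucketting output. SummaryWriter's
# add_histogram does its own binning, so returning raw predicted token ids
# replaces the removed manual bucket array.
import torch
from tensorboardX import SummaryWriter

logits = torch.randn(2, 16, 390)  # (batch, seq, vocab), dummy values
bucket = logits.argmax(-1).flatten().to(torch.int32).cpu()

writer = SummaryWriter('logs/bucket_demo')  # hypothetical log dir
writer.add_histogram('logits_bucket', bucket, global_step=0)
writer.close()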
18 changes: 2 additions & 16 deletions model.py
@@ -36,9 +36,9 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
         self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)

     def forward(self, x, length=None, writer=None):
-        if self.training:
+        if self.training or self.eval:
             _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token)
-            decoder, w = self.Decoder(x, look_ahead_mask)
+            decoder, w = self.Decoder(x, mask=look_ahead_mask)
             fc = self.fc(decoder)
             return fc.contiguous(), [weight.contiguous() for weight in w]
         else:
@@ -72,17 +72,3 @@ def generate(self, decode_fn, prior: torch.Tensor, length=2048, tf_board_writer:
             del look_ahead_mask
         decode_array = decode_array[0]
         return decode_array
-
-    # def teacher_forcing_forward(self, x, attn=False):
-    #     _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token)
-    #
-    #     predictions, w = self(
-    #         x, lookup_mask=look_ahead_mask,
-    #     )
-    #
-    #     if self._debug:
-    #         print('train step finished')
-    #     if attn:
-    #         return predictions, w
-    #     else:
-    #         return predictions
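
One thing to note about the new condition: `self.eval` without parentheses is nn.Module's bound eval() method, and a method object is always truthy, so `if self.training or self.eval:` always takes the first branch regardless of mode. That happens to suit how train.py uses it (single_mt.eval() followed by single_mt.forward(eval_x) still returns logits and attention weights), but it is not a mode check. A quick demonstration:

# `module.eval` is a method object, not the training flag; the flag itself
# is `module.training`, so the combined condition is truthy in either mode.
import torch

m = torch.nn.Linear(2, 2)
m.eval()                           # switch the module to eval mode
print(m.training)                  # False
print(bool(m.eval))                # True: bound method, always truthy
print(bool(m.training or m.eval))  # True in either mode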
51 changes: 32 additions & 19 deletions train.py
@@ -6,7 +6,6 @@
 from data import Data

 import utils
-import argparse
 import datetime
 import time

@@ -48,25 +47,27 @@
 opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
 scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

-metric_set = MetricsSet({
-    'accuracy': CategoricalAccuracy(),
-    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
-    'bucket': LogitsBucketting(config.vocab_size)
-})
-
 # multi-GPU set
 if torch.cuda.device_count() > 1:
     single_mt = mt
-    mt = torch.nn.DataParallel(mt)
+    mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count()-1)
 else:
     single_mt = mt

+# init metric set
+metric_set = MetricsSet({
+    'accuracy': CategoricalAccuracy(),
+    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
+    'bucket': LogitsBucketting(config.vocab_size)
+})
+
 print(mt)
 print('| Summary - Device Info : {}'.format(torch.cuda.device))

 # define tensorboard writer
 current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
-train_log_dir = 'logs/mt/'+config.experiment+'/train'
-eval_log_dir = 'logs/mt/'+config.experiment+'/eval'
+train_log_dir = 'logs/'+config.experiment+'/'+current_time+'/train'
+eval_log_dir = 'logs/'+config.experiment+'/'+current_time+'/eval'

 train_summary_writer = SummaryWriter(train_log_dir)
 eval_summary_writer = SummaryWriter(eval_log_dir)
@@ -93,6 +94,7 @@
         loss.backward()
         scheduler.step()
         end_time = time.time()
+
         if config.debug:
             print("[Loss]: {}".format(loss))

@@ -103,30 +105,41 @@

         # result_metrics = metric_set(sample, batch_y)
         if b % 100 == 0:
-            eval_x, eval_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq, 'eval')
-            eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
-            eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
+            single_mt.eval()
+            eval_x, eval_y = dataset.slide_seq2seq_batch(2, config.max_seq, 'eval')
+            eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, dtype=torch.int)
+            eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, dtype=torch.int)

-            eval_preiction, weights = mt.forward(eval_x)
+            eval_preiction, weights = single_mt.forward(eval_x)
+
             eval_metrics = metric_set(eval_preiction, eval_y)
-            torch.save(single_mt.state_dict(), args.model_dir+'/train-{}.pth'.format(idx))
+            torch.save(single_mt.state_dict(), args.model_dir+'/train-{}.pth'.format(e))
             if b == 0:
                 train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e)
                 train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e)
+                for i, weight in enumerate(weights):
+                    attn_log_name = "attn/layer-{}".format(i)
+                    utils.attention_image_summary(
+                        attn_log_name, weight, step=idx, writer=eval_summary_writer)

             eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx)
             eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx)
-            eval_summary_writer.add_histogram("logits_bucket", metrics["bucket"].get_bucket(), global_step=idx)
-            for i, weight in enumerate(weights):
-                attn_log_name = "attn/layer-{}".format(i)
-                utils.attention_image_summary(attn_log_name, weight, step=idx, writer=eval_summary_writer)
+            eval_summary_writer.add_histogram("logits_bucket", eval_metrics['bucket'], global_step=idx)

             print('\n====================================================')
             print('Epoch/Batch: {}/{}'.format(e, b))
             print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy']))
             print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['accuracy']))
         torch.cuda.empty_cache()
         idx += 1

+        # switch output device to: gpu-1 ~ gpu-n
+        sw_start = time.time()
+        mt.output_device = idx % (torch.cuda.device_count() -1) + 1
+        sw_end = time.time()
+        if config.debug:
+            print('output switch time: {}'.format(sw_end - sw_start) )

 torch.save(single_mt.state_dict(), args.model_dir+'/final.pth'.format(idx))
 eval_summary_writer.close()
 train_summary_writer.close()
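The multi-GPU changes in this file pair DataParallel(mt, output_device=...) with a per-step reassignment of mt.output_device, so the gather of per-GPU outputs does not always land on the same device. DataParallel.forward reads self.output_device on every call, which is what makes reassigning it take effect. A minimal sketch of the pattern; it assumes a machine with at least two CUDA GPUs and uses a toy model in place of MusicTransformer:

# Sketch: keep the raw module for checkpointing/eval (single_mt), wrap it
# in DataParallel for training, and rotate the gather device like the
# commit's `mt.output_device = idx % (torch.cuda.device_count() - 1) + 1`.
import torch

model = torch.nn.Linear(512, 512).cuda()
single_mt = model  # unwrapped handle: state_dict() without 'module.' prefixes
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model, output_device=torch.cuda.device_count() - 1)

for idx in range(4):
    x = torch.randn(8, 512, device='cuda')
    y = model(x)  # scatter inputs, replicate module, gather outputs
    if isinstance(model, torch.nn.DataParallel) and torch.cuda.device_count() > 2:
        # round-robin the gather target over gpu-1 .. gpu-(n-1)
        model.output_device = idx % (torch.cuda.device_count() - 1) + 1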
157 changes: 157 additions & 0 deletions train_ddp.py
@@ -0,0 +1,157 @@
+from model import MusicTransformer
+import custom
+from custom.metrics import *
+from custom.criterion import SmoothCrossEntropyLoss, CustomSchedule
+from custom.config import config
+from data import Data
+
+import utils
+import argparse
+import datetime
+import time
+import os
+
+from apex import amp
+from apex.parallel import DistributedDataParallel
+import torch
+import torch.optim as optim
+from tensorboardX import SummaryWriter
+
+
+# set config
+parser = custom.get_argument_parser()
+# set local rank for torch.distribute
+parser.add_argument('--local_rank', type=int)
+args = parser.parse_args()
+config.load(args.model_dir, args.configs, initialize=True)
+
+
+config.device = torch.device('cuda')
+
+# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
+# the 'WORLD_SIZE' environment variable will also be set automatically.
+config.distributed = False
+if 'WORLD_SIZE' in os.environ:
+    config.distributed = int(os.environ['WORLD_SIZE']) > 1
+
+# FOR DISTRIBUTED: Set the device according to local_rank.
+torch.cuda.set_device(args.local_rank)
+
+# FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
+# environment variables, and requires that you use init_method=`env://`.
+torch.distributed.init_process_group(backend='nccl',
+                                     init_method='env://',
+                                     rank=args.local_rank,
+                                     world_size=4 * torch.cuda.device_count())
+
+
+# load data
+dataset = Data(config.pickle_dir)
+print(dataset)
+
+
+# load model
+learning_rate = config.l_r
+
+
+# define model
+mt = MusicTransformer(
+    embedding_dim=config.embedding_dim,
+    vocab_size=config.vocab_size,
+    num_layer=config.num_layers,
+    max_seq=config.max_seq,
+    dropout=config.dropout,
+    debug=config.debug, loader_path=config.load_path
+)
+mt.to(config.device)
+opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
+scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)
+
+# Set model -> DDP
+single_mt = mt
+model, opt = amp.initialize(mt, scheduler.optimizer, opt_level="O1")
+mt = DistributedDataParallel(model)
+
+
+metric_set = MetricsSet({
+    'accuracy': CategoricalAccuracy().cpu(),
+    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
+    'bucket': LogitsBucketting(config.vocab_size).cpu()
+})
+
+print(mt)
+print('| Summary - Device Info : {}'.format(torch.cuda.device))
+
+# define tensorboard writer
+current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
+train_log_dir = 'logs/'+config.experiment+'/'+current_time+'/train'
+eval_log_dir = 'logs/'+config.experiment+'/'+current_time+'/eval'
+
+train_summary_writer = SummaryWriter(train_log_dir)
+eval_summary_writer = SummaryWriter(eval_log_dir)
+
+# Train Start
+print(">> Train start...")
+idx = 0
+for e in range(config.epochs):
+    print(">>> [Epoch was updated]")
+    for b in range(len(dataset.files) // config.batch_size):
+        scheduler.optimizer.zero_grad()
+        try:
+            batch_x, batch_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq)
+            batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
+            batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
+        except IndexError:
+            continue
+
+        start_time = time.time()
+        mt.train()
+        sample, _ = mt.forward(batch_x)
+        metrics = metric_set(sample, batch_y)
+        loss = metrics['loss']
+        with amp.scale_loss(loss, scheduler.optimizer) as scaled_loss:
+            scaled_loss.backward()
+        scheduler.step()
+        end_time = time.time()
+
+        if config.debug:
+            print("[Loss]: {}".format(loss))
+
+        train_summary_writer.add_scalar('loss', metrics['loss'], global_step=idx)
+        train_summary_writer.add_scalar('accuracy', metrics['accuracy'], global_step=idx)
+        train_summary_writer.add_scalar('learning_rate', scheduler.rate(), global_step=idx)
+        train_summary_writer.add_scalar('iter_p_sec', end_time-start_time, global_step=idx)
+
+        # result_metrics = metric_set(sample, batch_y)
+        if b % 100 == 0:
+            single_mt.eval()
+            eval_x, eval_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq, 'eval')
+            eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, dtype=torch.int)
+            eval_y = torch.from_numpy(eval_y).contiguous().cpu().to(config.device, dtype=torch.int)
+
+            eval_preiction, weights = single_mt.forward(eval_x)
+            eval_metrics = metric_set(eval_preiction.cpu(), eval_y.cpu())
+            torch.save(single_mt.state_dict(), args.model_dir+'/train-{}.pth'.format(e))
+            if b == 0:
+                train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e)
+                train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e)
+                for i, weight in enumerate(weights):
+                    attn_log_name = "attn/layer-{}".format(i)
+                    utils.attention_image_summary(attn_log_name, weight, step=idx, writer=eval_summary_writer)
+
+            eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx)
+            eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx)
+            eval_summary_writer.add_histogram("logits_bucket", eval_metrics['bucket'], global_step=idx)
+
+            print('\n====================================================')
+            print('Epoch/Batch: {}/{}'.format(e, b))
+            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy']))
+            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['accuracy']))
+        torch.cuda.empty_cache()
+        idx += 1
+
+torch.save(single_mt.state_dict(), args.model_dir+'/final.pth'.format(idx))
+eval_summary_writer.close()
+train_summary_writer.close()
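
train_ddp.py targets the 2019-era torch.distributed.launch workflow: the launcher spawns one process per GPU, exports WORLD_SIZE (which the script checks), MASTER_ADDR, MASTER_PORT, and RANK, and passes --local_rank to each process, which init_method='env://' plus the explicit rank argument then consume; amp.initialize with opt_level="O1" and amp.scale_loss handle the fp16 side. Note that rank=args.local_rank is only correct on a single node, and world_size=4 * torch.cuda.device_count() hard-codes a four-node assumption. A sketch of the launch and the environment it relies on; the trailing script flags for model_dir/configs are placeholders, since their spelling comes from custom.get_argument_parser():

# Launch (single node, 4 GPUs); the arguments after train_ddp.py are
# hypothetical placeholders:
#
#   python -m torch.distributed.launch --nproc_per_node=4 \
#       train_ddp.py <model_dir and config arguments>
#
# What the launcher exports for each worker (example values):
import os

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('WORLD_SIZE', '4')  # train_ddp.py: distributed if > 1
os.environ.setdefault('RANK', '0')        # one value per worker process

print(int(os.environ['WORLD_SIZE']) > 1)  # the config.distributed check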


