From cd841c41c35c852e986499ddabc4394405cfa8aa Mon Sep 17 00:00:00 2001 From: kevin-yang Date: Sat, 26 Oct 2019 21:47:09 +0900 Subject: [PATCH] add config set, compatible for multi-gpu, init commit for apex-ddp --- config/base.yml | 4 +- config/debug_train.yml | 8 +++ config/large.yml | 6 ++ config/train.yml | 6 +- config/train_ddp.yml | 8 +++ custom/layers.py | 2 +- custom/metrics.py | 28 ++++++-- model.py | 18 +---- train.py | 51 ++++++++----- train_ddp.py | 157 +++++++++++++++++++++++++++++++++++++++++ utils.py | 4 +- 11 files changed, 243 insertions(+), 49 deletions(-) create mode 100644 config/debug_train.yml create mode 100644 config/large.yml create mode 100644 config/train_ddp.yml create mode 100644 train_ddp.py diff --git a/config/base.yml b/config/base.yml index 1dcf7d6..c26966f 100644 --- a/config/base.yml +++ b/config/base.yml @@ -1,4 +1,6 @@ +experiment: 'debug' max_seq: 2048 embedding_dim: 256 num_layers: 6 -event_dim: 388 \ No newline at end of file +event_dim: 388 +fp16: diff --git a/config/debug_train.yml b/config/debug_train.yml new file mode 100644 index 0000000..56aee6d --- /dev/null +++ b/config/debug_train.yml @@ -0,0 +1,8 @@ +pickle_dir: 'MusicTransformer/dataset/processed' +epochs: 100 +batch_size: 2 +load_path: +dropout: 0.1 +debug: 'true' +l_r: 0.001 +label_smooth: 0.1 diff --git a/config/large.yml b/config/large.yml new file mode 100644 index 0000000..752204d --- /dev/null +++ b/config/large.yml @@ -0,0 +1,6 @@ +experiment: 'mt_large' +max_seq: 2048 +embedding_dim: 512 +num_layers: 6 +event_dim: 388 +fp16: diff --git a/config/train.yml b/config/train.yml index 0fd6ea1..e66b361 100644 --- a/config/train.yml +++ b/config/train.yml @@ -1,8 +1,8 @@ -pickle_dir: '/data/private/MusicTransformer/dataset/processed' +pickle_dir: 'MusicTransformer/dataset/processed' epochs: 100 -batch_size: 3 +batch_size: 8 load_path: dropout: 0.1 debug: 'true' l_r: 0.001 -label_smooth: 0.1 \ No newline at end of file +label_smooth: 0.1 diff --git a/config/train_ddp.yml b/config/train_ddp.yml new file mode 100644 index 0000000..e66b361 --- /dev/null +++ b/config/train_ddp.yml @@ -0,0 +1,8 @@ +pickle_dir: 'MusicTransformer/dataset/processed' +epochs: 100 +batch_size: 8 +load_path: +dropout: 0.1 +debug: 'true' +l_r: 0.001 +label_smooth: 0.1 diff --git a/custom/layers.py b/custom/layers.py index 9a6bcb1..1d07701 100644 --- a/custom/layers.py +++ b/custom/layers.py @@ -98,7 +98,7 @@ def forward(self, inputs, mask=None, **kwargs): logits = logits / math.sqrt(self.dh) if mask is not None: - logits += (mask * -1e9).to(logits.dtype) + logits += (mask.to(torch.int64) * -1e9).to(logits.dtype) attention_weights = F.softmax(logits, -1) attention = torch.matmul(attention_weights, v) diff --git a/custom/metrics.py b/custom/metrics.py index c3e1923..aa00667 100644 --- a/custom/metrics.py +++ b/custom/metrics.py @@ -24,7 +24,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): :return: """ bool_acc = input.long() == target.long() - return bool_acc.sum() / bool_acc.numel() + return bool_acc.sum().to(torch.float) / bool_acc.numel() + + +class MockAccuracy(Accuracy): + def __init__(self): + super().__init__() + + def forward(self, input: torch.Tensor, target: torch.Tensor): + return super().forward(input, target) class CategoricalAccuracy(Accuracy): @@ -45,13 +53,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): class LogitsBucketting(_Metric): def __init__(self, vocab_size): super().__init__() - self.bucket = np.array([0] * vocab_size) def forward(self, input: torch.Tensor, 
target: torch.Tensor): - self.bucket[input.flatten().to(torch.int32)] += 1 - - def get_bucket(self): - return self.bucket + return input.argmax(-1).flatten().to(torch.int32).cpu() class MetricsSet(object): @@ -64,4 +68,14 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor): def forward(self, input: torch.Tensor, target: torch.Tensor): # return [metric(input, target) for metric in self.metrics] - return {k: metric(input, target) for k, metric in self.metrics.items()} + return { + k: metric(input.to(target.device), target) + for k, metric in self.metrics.items()} + + +if __name__ == '__main__': + met = MockAccuracy() + test_tensor1 = torch.ones((3,2)).contiguous().cuda().to(non_blocking=True, dtype=torch.int) + test_tensor2 = torch.ones((3,2)).contiguous().cuda().to(non_blocking=True, dtype=torch.int) + test_tensor3 = torch.zeros((3,2)) + print(met(test_tensor1, test_tensor2)) diff --git a/model.py b/model.py index 5951818..fdb6e92 100644 --- a/model.py +++ b/model.py @@ -36,9 +36,9 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6, self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size) def forward(self, x, length=None, writer=None): - if self.training: + if self.training or self.eval: _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token) - decoder, w = self.Decoder(x, look_ahead_mask) + decoder, w = self.Decoder(x, mask=look_ahead_mask) fc = self.fc(decoder) return fc.contiguous(), [weight.contiguous() for weight in w] else: @@ -72,17 +72,3 @@ def generate(self, decode_fn, prior: torch.Tensor, length=2048, tf_board_writer: del look_ahead_mask decode_array = decode_array[0] return decode_array - - # def teacher_forcing_forward(self, x, attn=False): - # _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token) - # - # predictions, w = self( - # x, lookup_mask=look_ahead_mask, - # ) - # - # if self._debug: - # print('train step finished') - # if attn: - # return predictions, w - # else: - # return predictions diff --git a/train.py b/train.py index fb1c3ab..755bea9 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,6 @@ from data import Data import utils -import argparse import datetime import time @@ -48,25 +47,27 @@ opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) scheduler = CustomSchedule(config.embedding_dim, optimizer=opt) -metric_set = MetricsSet({ - 'accuracy': CategoricalAccuracy(), - 'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token), - 'bucket': LogitsBucketting(config.vocab_size) -}) - # multi-GPU set if torch.cuda.device_count() > 1: single_mt = mt - mt = torch.nn.DataParallel(mt) + mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count()-1) else: single_mt = mt +# init metric set +metric_set = MetricsSet({ + 'accuracy': CategoricalAccuracy(), + 'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token), + 'bucket': LogitsBucketting(config.vocab_size) +}) + print(mt) +print('| Summary - Device Info : {}'.format(torch.cuda.device)) # define tensorboard writer current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') -train_log_dir = 'logs/mt/'+config.experiment+'/train' -eval_log_dir = 'logs/mt/'+config.experiment+'/eval' +train_log_dir = 'logs/'+config.experiment+'/'+current_time+'/train' +eval_log_dir = 'logs/'+config.experiment+'/'+current_time+'/eval' train_summary_writer = SummaryWriter(train_log_dir) eval_summary_writer = SummaryWriter(eval_log_dir) @@ 
-93,6 +94,7 @@ loss.backward() scheduler.step() end_time = time.time() + if config.debug: print("[Loss]: {}".format(loss)) @@ -103,30 +105,41 @@ # result_metrics = metric_set(sample, batch_y) if b % 100 == 0: - eval_x, eval_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq, 'eval') - eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int) - eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int) + single_mt.eval() + eval_x, eval_y = dataset.slide_seq2seq_batch(2, config.max_seq, 'eval') + eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, dtype=torch.int) + eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, dtype=torch.int) + + eval_preiction, weights = single_mt.forward(eval_x) - eval_preiction, weights = mt.forward(eval_x) eval_metrics = metric_set(eval_preiction, eval_y) - torch.save(single_mt.state_dict(), args.model_dir+'/train-{}.pth'.format(idx)) + torch.save(single_mt.state_dict(), args.model_dir+'/train-{}.pth'.format(e)) if b == 0: train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e) train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e) + for i, weight in enumerate(weights): + attn_log_name = "attn/layer-{}".format(i) + utils.attention_image_summary( + attn_log_name, weight, step=idx, writer=eval_summary_writer) eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx) eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx) - eval_summary_writer.add_histogram("logits_bucket", metrics["bucket"].get_bucket(), global_step=idx) - for i, weight in enumerate(weights): - attn_log_name = "attn/layer-{}".format(i) - utils.attention_image_summary(attn_log_name, weight, step=idx, writer=eval_summary_writer) + eval_summary_writer.add_histogram("logits_bucket", eval_metrics['bucket'], global_step=idx) print('\n====================================================') print('Epoch/Batch: {}/{}'.format(e, b)) print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy'])) print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['accuracy'])) + torch.cuda.empty_cache() idx += 1 + # switch output device to: gpu-1 ~ gpu-n + sw_start = time.time() + mt.output_device = idx % (torch.cuda.device_count() -1) + 1 + sw_end = time.time() + if config.debug: + print('output switch time: {}'.format(sw_end - sw_start) ) + torch.save(single_mt.state_dict(), args.model_dir+'/final.pth'.format(idx)) eval_summary_writer.close() train_summary_writer.close() diff --git a/train_ddp.py b/train_ddp.py new file mode 100644 index 0000000..655fd46 --- /dev/null +++ b/train_ddp.py @@ -0,0 +1,157 @@ +from model import MusicTransformer +import custom +from custom.metrics import * +from custom.criterion import SmoothCrossEntropyLoss, CustomSchedule +from custom.config import config +from data import Data + +import utils +import argparse +import datetime +import time +import os + +from apex import amp +from apex.parallel import DistributedDataParallel +import torch +import torch.optim as optim +from tensorboardX import SummaryWriter + + +# set config +parser = custom.get_argument_parser() +# set local rank for torch.distribute +parser.add_argument('--local_rank', type=int) +args = parser.parse_args() +config.load(args.model_dir, args.configs, initialize=True) + + +config.device = torch.device('cuda') + +# FOR DISTRIBUTED: If we are running 
under torch.distributed.launch, +# the 'WORLD_SIZE' environment variable will also be set automatically. +config.distributed = False +if 'WORLD_SIZE' in os.environ: + config.distributed = int(os.environ['WORLD_SIZE']) > 1 + +# FOR DISTRIBUTED: Set the device according to local_rank. +torch.cuda.set_device(args.local_rank) + +# FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide +# environment variables, and requires that you use init_method=`env://`. +torch.distributed.init_process_group(backend='nccl', + init_method='env://', + rank=args.local_rank, + world_size=4 * torch.cuda.device_count()) + + +# load data +dataset = Data(config.pickle_dir) +print(dataset) + + +# load model +learning_rate = config.l_r + + +# define model +mt = MusicTransformer( + embedding_dim=config.embedding_dim, + vocab_size=config.vocab_size, + num_layer=config.num_layers, + max_seq=config.max_seq, + dropout=config.dropout, + debug=config.debug, loader_path=config.load_path +) +mt.to(config.device) +opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) +scheduler = CustomSchedule(config.embedding_dim, optimizer=opt) + +# Set model -> DDP +single_mt = mt +model, opt = amp.initialize(mt, scheduler.optimizer, opt_level="O1") +mt = DistributedDataParallel(model) + + +metric_set = MetricsSet({ + 'accuracy': CategoricalAccuracy().cpu(), + 'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token), + 'bucket': LogitsBucketting(config.vocab_size).cpu() +}) + +print(mt) +print('| Summary - Device Info : {}'.format(torch.cuda.device)) + +# define tensorboard writer +current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') +train_log_dir = 'logs/'+config.experiment+'/'+current_time+'/train' +eval_log_dir = 'logs/'+config.experiment+'/'+current_time+'/eval' + +train_summary_writer = SummaryWriter(train_log_dir) +eval_summary_writer = SummaryWriter(eval_log_dir) + +# Train Start +print(">> Train start...") +idx = 0 +for e in range(config.epochs): + print(">>> [Epoch was updated]") + for b in range(len(dataset.files) // config.batch_size): + scheduler.optimizer.zero_grad() + try: + batch_x, batch_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq) + batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int) + batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int) + except IndexError: + continue + + start_time = time.time() + mt.train() + sample, _ = mt.forward(batch_x) + metrics = metric_set(sample, batch_y) + loss = metrics['loss'] + with amp.scale_loss(loss, scheduler.optimizer) as scaled_loss: + scaled_loss.backward() + scheduler.step() + end_time = time.time() + + if config.debug: + print("[Loss]: {}".format(loss)) + + train_summary_writer.add_scalar('loss', metrics['loss'], global_step=idx) + train_summary_writer.add_scalar('accuracy', metrics['accuracy'], global_step=idx) + train_summary_writer.add_scalar('learning_rate', scheduler.rate(), global_step=idx) + train_summary_writer.add_scalar('iter_p_sec', end_time-start_time, global_step=idx) + + # result_metrics = metric_set(sample, batch_y) + if b % 100 == 0: + single_mt.eval() + eval_x, eval_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq, 'eval') + eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, dtype=torch.int) + eval_y = torch.from_numpy(eval_y).contiguous().cpu().to(config.device, dtype=torch.int) + + eval_preiction, weights = 
single_mt.forward(eval_x) + eval_metrics = metric_set(eval_preiction.cpu(), eval_y.cpu()) + torch.save(single_mt.state_dict(), args.model_dir+'/train-{}.pth'.format(e)) + if b == 0: + train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e) + train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e) + for i, weight in enumerate(weights): + attn_log_name = "attn/layer-{}".format(i) + utils.attention_image_summary(attn_log_name, weight, step=idx, writer=eval_summary_writer) + + eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx) + eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx) + eval_summary_writer.add_histogram("logits_bucket", eval_metrics['bucket'], global_step=idx) + + print('\n====================================================') + print('Epoch/Batch: {}/{}'.format(e, b)) + print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy'])) + print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['accuracy'])) + torch.cuda.empty_cache() + idx += 1 + +torch.save(single_mt.state_dict(), args.model_dir+'/final.pth'.format(idx)) +eval_summary_writer.close() +train_summary_writer.close() + + diff --git a/utils.py b/utils.py index 72d023a..bdf1edb 100644 --- a/utils.py +++ b/utils.py @@ -72,7 +72,7 @@ def get_masked_with_pad_tensor(size, src, trg, pad_token): trg_pad_tensor = torch.ones_like(trg).to(trg.device.type) * pad_token dec_trg_mask = trg == trg_pad_tensor # boolean reversing i.e) True * -1 + 1 = False - seq_mask = sequence_mask(torch.arange(1, size+1).to(trg.device), size) * -1 + 1 + seq_mask = ~sequence_mask(torch.arange(1, size+1).to(trg.device), size) # look_ahead_mask = torch.max(dec_trg_mask, seq_mask) look_ahead_mask = dec_trg_mask | seq_mask @@ -89,7 +89,7 @@ def get_mask_tensor(size): :return: """ # boolean reversing i.e) True * -1 + 1 = False - seq_mask = sequence_mask(torch.arange(1, size + 1), size) * -1 + 1 + seq_mask = ~sequence_mask(torch.arange(1, size + 1), size) return seq_mask
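
Usage note for the new DDP path: train_ddp.py is written for the torch.distributed.launch helper, which sets WORLD_SIZE in the environment and passes --local_rank to each spawned process, matching the --local_rank argument added to the parser above. A single-node run would look roughly like the command below; the --model_dir and --configs flag names are assumptions inferred from args.model_dir and args.configs, since custom.get_argument_parser() is not part of this diff.

    python -m torch.distributed.launch --nproc_per_node=<num_gpus> train_ddp.py \
        --model_dir <save_dir> \
        --configs config/large.yml config/train_ddp.yml

train.py remains a single-process script: it wraps the model in torch.nn.DataParallel only when more than one GPU is visible, so it is still started as a plain python train.py invocation with the same style of arguments.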
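
For completeness, a minimal Python sketch of how the new config set is expected to be consumed, mirroring the config.load(args.model_dir, args.configs, initialize=True) call in both training scripts. Passing the configs as a list of yml paths, and splitting them into a model config (base.yml or large.yml) plus a training config (train.yml or train_ddp.yml), are assumptions drawn from the files added above rather than from custom/config.py itself.

    # Sketch only; the save directory and file pairing are hypothetical.
    from custom.config import config

    config.load('save/mt_large',                               # model_dir (output/checkpoint dir)
                ['config/large.yml', 'config/train_ddp.yml'],  # model config + training config
                initialize=True)

    # Model hyper-parameters come from large.yml, training settings from train_ddp.yml.
    print(config.embedding_dim, config.max_seq, config.batch_size, config.l_r)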