From 7b6c6c5fc1adf4fc685125deccb567e4c1bd4f4e Mon Sep 17 00:00:00 2001
From: adamoudad
Date: Mon, 30 Nov 2020 08:13:21 +0900
Subject: [PATCH 1/2] Fix training for single GPU

---
 train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 6d7b618..1c97d37 100644
--- a/train.py
+++ b/train.py
@@ -135,7 +135,8 @@
 
         # switch output device to: gpu-1 ~ gpu-n
         sw_start = time.time()
-        mt.output_device = idx % (torch.cuda.device_count() -1) + 1
+        if torch.cuda.device_count() > 1:
+            mt.output_device = idx % (torch.cuda.device_count() -1) + 1
         sw_end = time.time()
         if config.debug:
             print('output switch time: {}'.format(sw_end - sw_start) )

From 68911c7c74252ebecdf20055f80f3b7a437d91f9 Mon Sep 17 00:00:00 2001
From: adamoudad
Date: Mon, 30 Nov 2020 08:25:15 +0900
Subject: [PATCH 2/2] Fix model output when training

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index 52566ef..7a3a5ec 100644
--- a/model.py
+++ b/model.py
@@ -40,7 +40,7 @@ def forward(self, x, length=None, writer=None):
             _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(self.max_seq, x, x, config.pad_token)
             decoder, w = self.Decoder(x, mask=look_ahead_mask)
             fc = self.fc(decoder)
-            return fc.contiguous() if self.training else fc.contiguous(), [weight.contiguous() for weight in w]
+            return fc.contiguous() if self.training else (fc.contiguous(), [weight.contiguous() for weight in w])
         else:
             return self.generate(x, length, None).contiguous().tolist()
 