From 13641268b59df5cf90d27b451d87ab58b6a07055 Mon Sep 17 00:00:00 2001 From: Li Dong Date: Thu, 2 Apr 2020 18:40:19 +0800 Subject: [PATCH] s2s CPU --- s2s-ft/s2s_ft/modeling_decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2s-ft/s2s_ft/modeling_decoding.py b/s2s-ft/s2s_ft/modeling_decoding.py index 89dec6565..c235be143 100644 --- a/s2s-ft/s2s_ft/modeling_decoding.py +++ b/s2s-ft/s2s_ft/modeling_decoding.py @@ -1656,7 +1656,7 @@ def get_dup_ngram_candidates(seq, n): forbid_word_mask = torch.tensor( buf_matrix, dtype=log_scores.dtype) forbid_word_mask = torch.reshape( - forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda() + forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device) else: forbid_word_mask = None next_pos += 1